diff --git a/.gitignore b/.gitignore
index 808559f8c..929fa1d4d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -150,3 +150,6 @@
 config.toml
 desktop.ini
 .viminfo
+
+# SAP RFC lib
+sap_netweaver_rfc
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7b213dd11..f7bc3fb7e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,18 +3,45 @@
 All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
-
 ## [Unreleased]
+## [0.3.0] - 2022-02-16
+### Added
+- new source `SAPRFC` for connecting with SAP using the `pyRFC` library (requires pyrfc as well as the SAP NW RFC library, which can be downloaded [here](https://support.sap.com/en/product/connectors/nwrfcsdk.html))
+- new source `DuckDB` for connecting with the `DuckDB` database
+- new task `SAPRFCToDF` for loading data from SAP to a pandas DataFrame
+- new tasks, `DuckDBQuery` and `DuckDBCreateTableFromParquet`, for interacting with DuckDB
+- new flow `SAPToDuckDB` for moving data from SAP to DuckDB
+- new task `CheckColumnOrder`
+- documentation for the C4C connection with `url` and `report_url`
+- `SQLiteInsert` now checks whether the DataFrame is empty or the object is not a DataFrame
+- KeyVault support in the `SharepointToDF` task
+- KeyVault support in `CloudForCustomers` tasks
+
+### Changed
+- pinned the Prefect version to 0.15.11
+- `df_to_csv` now creates directories if they don't exist
+- `ADLSToAzureSQL` - unnecessary "\t" characters in CSV columns are now removed
+
+### Fixed
+- an issue with DuckDB calls seeing the initial DB snapshot instead of the updated state (#282)
+- optimized the C4C connection with `url` and `report_url`
+- column mapper in the C4C source
+
 ## [0.2.15] - 2022-01-12
 ### Added
 - new option to `ADLSToAzureSQL` Flow - `if_exists="delete"`
 - `SQL` source: `create_table()` already handles `if_exists`; now it handles a new option for `if_exists()`
 - `C4CToDF` and `C4CReportToDF` tasks are provided as a class instead of function
+
 ### Fixed
 - Appending issue within CloudForCustomers source
+- An early return bug in `UKCarbonIntensity` in `to_df` method
+
+
 ## [0.2.14] - 2021-12-01
+
 ### Fixed
 - authorization issue within `CloudForCustomers` source
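For context, the new `DuckDB` source listed in the changelog above can be used roughly as follows (a minimal sketch based on `viadot/sources/duckdb.py` and `tests/unit/test_duckdb.py` from this diff; the database path, schema, table, and Parquet path are placeholders):

```
from viadot.sources import DuckDB

# Placeholder local database file; DuckDB creates it on the first connection.
duckdb = DuckDB(credentials={"database": "sandbox.duckdb"})

# Register Parquet data (glob patterns are allowed) as a new table;
# the schema is created if it does not exist yet.
duckdb.create_table_from_parquet(
    schema="staging", table="countries", path="data/*.parquet"
)

# Read the table back into a pandas DataFrame.
df = duckdb.to_df("SELECT * FROM staging.countries")
```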
diff --git a/README.md b/README.md
index b0a304dff..11862577c 100644
--- a/README.md
+++ b/README.md
@@ -115,10 +115,20 @@ However, when developing, the easiest way is to use the provided Jupyter Lab con
 Please follow the standards and best practices used within the library (eg. when adding tasks, see how other tasks are constructed, etc.). For any questions, please reach out to us here on GitHub.
-
 ### Style guidelines
 - the code should be formatted with Black using default settings (easiest way is to use the VSCode extension)
 - commit messages should:
   - begin with an emoji
   - start with one of the following verbs, capitalized, immediately after the summary emoji: "Added", "Updated", "Removed", "Fixed", "Renamed", and, sporadically, other ones, such as "Upgraded", "Downgraded", or whatever you find relevant for your particular situation
-  - contain a useful description of what the commit is doing
\ No newline at end of file
+  - contain a useful description of what the commit is doing
+
+## Set up Black for development in VSCode
+If you want to contribute, your code should be formatted with Black. To set up Black in Visual Studio Code, follow the instructions below.
+1. Install `black` in your environment by running the following in the terminal:
+```
+pip install black
+```
+2. Open the settings (the gear icon in the bottom left corner), select `Settings`, or press "Ctrl" + ",".
+3. Find the `Format On Save` setting and check the box.
+4. Find the `Python Formatting Provider` setting and select "black" from the drop-down list.
+5. Your code should now be formatted automatically on save.
\ No newline at end of file
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 1edf83e34..6eeab2934 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -1,4 +1,4 @@
-FROM prefecthq/prefect:latest-python3.8
+FROM prefecthq/prefect:0.15.11-python3.8
 
 # Add user
@@ -11,10 +11,12 @@ RUN useradd --create-home viadot && \
 RUN groupadd docker && \
     usermod -aG docker viadot
 
+
 # Release File Error
 # https://stackoverflow.com/questions/63526272/release-file-is-not-valid-yet-docker
 RUN echo "Acquire::Check-Valid-Until \"false\";\nAcquire::Check-Date \"false\";" | cat > /etc/apt/apt.conf.d/10no--check-valid-until
 
+
 # System packages
 RUN apt update && yes | apt install vim unixodbc-dev build-essential \
     curl python3-dev libboost-all-dev libpq-dev graphviz python3-gi sudo git
@@ -36,7 +38,14 @@ RUN curl https://packages.microsoft.com/keys/microsoft.asc | apt-key add - && \
 COPY docker/odbcinst.ini /etc
 
+
 # Python env
+
+# This one's needed for the SAP RFC connector.
+# It must be installed here as the SAP package does not define its dependencies,
+# so `pip install pyrfc` breaks if all deps are not already present.
+RUN pip install cython==0.29.24
+
 WORKDIR /code
 COPY requirements.txt /code/
 RUN pip install --upgrade pip
diff --git a/docs/references/api_sources.md b/docs/references/api_sources.md
index 27db38553..026d61299 100644
--- a/docs/references/api_sources.md
+++ b/docs/references/api_sources.md
@@ -2,4 +2,6 @@
 ::: viadot.sources.uk_carbon_intensity.UKCarbonIntensity
 
-::: viadot.sources.supermetrics.Supermetrics
\ No newline at end of file
+::: viadot.sources.supermetrics.Supermetrics
+
+::: viadot.sources.cloud_for_customers.CloudForCustomers
diff --git a/docs/references/flows_library.md b/docs/references/flows_library.md
index d16868c6f..9c7bdd8df 100644
--- a/docs/references/flows_library.md
+++ b/docs/references/flows_library.md
@@ -8,5 +8,4 @@
 ::: viadot.flows.azure_sql_transform.AzureSQLTransform
 ::: viadot.flows.supermetrics_to_adls.SupermetricsToADLS
 ::: viadot.flows.supermetrics_to_azure_sql.SupermetricsToAzureSQL
-
-
+::: viadot.flows.cloud_for_customers_report_to_adls.CloudForCustomersReportToADLS
diff --git a/docs/references/task_library.md b/docs/references/task_library.md
index d6817dac4..383e3d143 100644
--- a/docs/references/task_library.md
+++ b/docs/references/task_library.md
@@ -9,7 +9,6 @@
 ::: viadot.tasks.azure_data_lake.AzureDataLakeCopy
 ::: viadot.tasks.azure_data_lake.AzureDataLakeList
-
 ::: viadot.tasks.azure_key_vault.AzureKeyVaultSecret
 ::: viadot.tasks.azure_key_vault.CreateAzureKeyVaultSecret
 ::: viadot.tasks.azure_key_vault.DeleteAzureKeyVaultSecret
@@ -29,4 +28,7 @@
 ::: viadot.tasks.supermetrics.SupermetricsToDF
 
 :::viadot.task_utils.add_ingestion_metadata_task
-:::viadot.task_utils.get_latest_timestamp_file_path
\ No newline at end of file
+:::viadot.task_utils.get_latest_timestamp_file_path
+
+::: viadot.tasks.cloud_for_customers.C4CToDF
+::: viadot.tasks.cloud_for_customers.C4CReportToDF
diff --git a/docs/tutorials/sharepoint.md b/docs/tutorials/sharepoint.md
new file mode 100644
index 000000000..14939ea2b
--- /dev/null
+++ b/docs/tutorials/sharepoint.md
@@ -0,0 +1,8 @@
+# How to pull an Excel file from SharePoint
+
+With Viadot you can download an Excel file from SharePoint and then upload it to Azure Data Lake. You can set a URL to the file on SharePoint and specify parameters such as the path to the local Excel file, the number of rows, and the sheet number to be extracted.
+
+## Pull data from SharePoint and save the output as a CSV file on Azure Data Lake
+
+To pull an Excel file from SharePoint, we create a flow based on `SharepointToADLS`:
+:::viadot.flows.SharepointToADLS
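A related building block, the `SharepointToDF` task, appears in this diff's integration tests and can also be run on its own. A minimal sketch based on `tests/integration/test_sharepoint.py` (the Prefect secret name, file name, and URL are placeholders):

```
from prefect.tasks.secrets import PrefectSecret
from viadot.tasks.sharepoint import SharepointToDF

# Prefect secret pointing at the Key Vault secret with the Sharepoint
# credentials, mirroring the integration test.
credentials_secret = PrefectSecret("SHAREPOINT_KV").run()

task = SharepointToDF()
df = task.run(
    credentials_secret=credentials_secret,
    sheet_number=0,  # first sheet of the workbook
    path_to_file="Questionnaires.xlsx",  # local path the file is downloaded to
    url_to_file="https://example.sharepoint.com/sites/team/Questionnaires.xlsx",  # placeholder URL
)
```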
diff --git a/requirements.txt b/requirements.txt
index 380f70164..6074b74dc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,7 @@ mkdocs-material==8.0.1
 mkdocs==1.2.3
 mkdocstrings==0.16.2
 pandas==1.3.4
-prefect[viz]==0.15.5
+prefect[viz]==0.15.11
 pyarrow==6.0.1
 pyodbc==4.0.32
 pytest==6.2.5
@@ -24,4 +24,6 @@ PyGithub==1.55
 Shapely==1.8.0
 imagehash==4.2.1
 visions==0.7.4
-sharepy==1.3.0
\ No newline at end of file
+sharepy==1.3.0
+sql-metadata==2.3.0
+duckdb==0.3.1
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 75f80d77e..02ef25c02 100644
--- a/setup.py
+++ b/setup.py
@@ -13,6 +13,11 @@ def get_version(package: str):
 with open("README.md", "r") as fh:
     long_description = fh.read()
 
+
+extras = {
+    "sap": ["pyrfc==2.5.0", "sql-metadata==2.3.0"],
+}
+
 setuptools.setup(
     name="viadot",
     version=get_version("viadot"),
@@ -22,6 +27,7 @@ def get_version(package: str):
     long_description_content_type="text/markdown",
     url="https://github.com/dyvenia/viadot",
     packages=setuptools.find_packages(),
+    extras_require=extras,
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
diff --git a/tests/conftest.py b/tests/conftest.py
index 3e5b9327c..28f9fff39 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -19,6 +19,11 @@ def TEST_PARQUET_FILE_PATH():
     return "test_data_countries.parquet"
 
 
+@pytest.fixture(scope="session")
+def TEST_PARQUET_FILE_PATH_2():
+    return "test_data_countries_2.parquet"
+
+
 @pytest.fixture(scope="session")
 def TEST_CSV_FILE_BLOB_PATH():
     return "tests/test.csv"
@@ -44,3 +49,10 @@ def create_test_parquet_file(DF, TEST_PARQUET_FILE_PATH):
     DF.to_parquet(TEST_PARQUET_FILE_PATH, index=False)
     yield
     os.remove(TEST_PARQUET_FILE_PATH)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def create_test_parquet_file_2(DF, TEST_PARQUET_FILE_PATH_2):
+    DF.to_parquet(TEST_PARQUET_FILE_PATH_2, index=False)
+    yield
+    os.remove(TEST_PARQUET_FILE_PATH_2)
diff --git a/tests/integration/flows/test_adls_to_azure_sql.py b/tests/integration/flows/test_adls_to_azure_sql.py
index ff42c141b..ca6c30a25 100644
--- a/tests/integration/flows/test_adls_to_azure_sql.py
+++ b/tests/integration/flows/test_adls_to_azure_sql.py
@@ -1,4 +1,6 @@
+import pandas as pd
 from viadot.flows import ADLSToAzureSQL
+from viadot.flows.adls_to_azure_sql import df_to_csv_task
 
 
 def test_get_promoted_adls_path_csv_file():
@@ -44,3 +46,12 @@ def test_get_promoted_adls_path_dir_starts_with_slash():
     flow = ADLSToAzureSQL(name="test", adls_path=adls_path_dir_starts_with_slash)
     promoted_path = flow.get_promoted_path(env="conformed")
     assert promoted_path == "conformed/supermetrics/adls_ga_load_times_fr_test.csv"
+
+
+def test_df_to_csv_task():
+    d = {"col1": ["rat", "\tdog"], "col2": ["cat", 4]}
+    df = pd.DataFrame(data=d)
+    assert df["col1"].astype(str).str.contains("\t")[1] == True
+    task = df_to_csv_task
+    task.run(df, "result.csv")
+    assert df["col1"].astype(str).str.contains("\t")[1] != True
diff --git a/tests/integration/flows/test_sap_to_duckdb.py b/tests/integration/flows/test_sap_to_duckdb.py
new file
mode 100644 index 000000000..9b6f02167 --- /dev/null +++ b/tests/integration/flows/test_sap_to_duckdb.py @@ -0,0 +1,47 @@ +import os + +from viadot.config import local_config + +try: + import pyrfc +except ModuleNotFoundError: + raise + +from viadot.flows import SAPToDuckDB + +sap_test_creds = local_config.get("SAP").get("TEST") +duckdb_creds = {"database": "test1.duckdb"} + + +def test_sap_to_duckdb(): + flow = SAPToDuckDB( + name="SAPToDuckDB flow test", + query=""" + select + ,CLIENT as client + ,KNUMV as number_of_the_document_condition + ,KPOSN as condition_item_number + ,STUNR as step_number + ,KAPPL as application + from PRCD_ELEMENTS + where KNUMV = '2003393196' + and KPOSN = '000001' + or STUNR = '570' + and CLIENT = '009' + limit 3 + """, + schema="main", + table="test", + local_file_path="local.parquet", + table_if_exists="replace", + sap_credentials=sap_test_creds, + duckdb_credentials=duckdb_creds, + ) + + result = flow.run() + assert result.is_successful() + + task_results = result.result.values() + assert all([task_result.is_successful() for task_result in task_results]) + + os.remove("test1.duckdb") diff --git a/tests/integration/tasks/test_azure_sql.py b/tests/integration/tasks/test_azure_sql.py index 7642c7c86..30eb2f7e4 100644 --- a/tests/integration/tasks/test_azure_sql.py +++ b/tests/integration/tasks/test_azure_sql.py @@ -1,6 +1,9 @@ import logging +import pandas as pd +import pytest +from viadot.exceptions import ValidationError -from viadot.tasks import AzureSQLCreateTable, AzureSQLDBQuery +from viadot.tasks import AzureSQLCreateTable, AzureSQLDBQuery, CheckColumnOrder logger = logging.getLogger(__name__) @@ -63,3 +66,69 @@ def test_azure_sql_run_drop_query(): """ exists = bool(sql_query_task.run(list_table_info_query)) assert not exists + + +def test_check_column_order_append_same_col_number(caplog): + create_table_task = AzureSQLCreateTable() + with caplog.at_level(logging.INFO): + create_table_task.run( + schema=SCHEMA, + table=TABLE, + dtypes={"id": "INT", "name": "VARCHAR(25)", "street": "VARCHAR(25)"}, + if_exists="replace", + ) + assert "Successfully created table sandbox" in caplog.text + + data = {"id": [1], "street": ["Green"], "name": ["Tom"]} + df = pd.DataFrame(data) + + check_column_order = CheckColumnOrder() + with caplog.at_level(logging.WARNING): + check_column_order.run(table=TABLE, if_exists="append", df=df) + + assert ( + "Detected column order difference between the CSV file and the table. Reordering..." 
+ in caplog.text + ) + + +def test_check_column_order_append_diff_col_number(caplog): + create_table_task = AzureSQLCreateTable() + with caplog.at_level(logging.INFO): + create_table_task.run( + schema=SCHEMA, + table=TABLE, + dtypes={"id": "INT", "name": "VARCHAR(25)", "street": "VARCHAR(25)"}, + if_exists="replace", + ) + assert "Successfully created table sandbox" in caplog.text + + data = {"id": [1], "age": ["40"], "street": ["Green"], "name": ["Tom"]} + df = pd.DataFrame(data) + print(f"COMP: \ndf: {df.columns} \nsql: ") + check_column_order = CheckColumnOrder() + with pytest.raises( + ValidationError, + match=r"Detected discrepancies in number of columns or different column names between the CSV file and the SQL table!", + ): + check_column_order.run(table=TABLE, if_exists="append", df=df) + + +def test_check_column_order_replace(caplog): + create_table_task = AzureSQLCreateTable() + with caplog.at_level(logging.INFO): + create_table_task.run( + schema=SCHEMA, + table=TABLE, + dtypes={"id": "INT", "name": "VARCHAR(25)", "street": "VARCHAR(25)"}, + if_exists="replace", + ) + assert "Successfully created table sandbox" in caplog.text + + data = {"id": [1], "street": ["Green"], "name": ["Tom"]} + df = pd.DataFrame(data) + + check_column_order = CheckColumnOrder() + with caplog.at_level(logging.INFO): + check_column_order.run(table=TABLE, if_exists="replace", df=df) + assert "The table will be replaced." in caplog.text diff --git a/tests/integration/tasks/test_cloud_for_customers.py b/tests/integration/tasks/test_cloud_for_customers.py index e5bb0d93a..50a6a1b8c 100644 --- a/tests/integration/tasks/test_cloud_for_customers.py +++ b/tests/integration/tasks/test_cloud_for_customers.py @@ -1,5 +1,6 @@ from viadot.tasks import C4CToDF, C4CReportToDF from viadot.config import local_config +from prefect.tasks.secrets import PrefectSecret def test_c4c_to_df(): @@ -8,7 +9,6 @@ def test_c4c_to_df(): c4c_to_df = C4CToDF() df = c4c_to_df.run(url=url, endpoint=endpoint) answer = df.head() - assert answer.shape[1] == 23 @@ -21,3 +21,11 @@ def test_c4c_report_to_df(): answer = df.head() assert answer.shape[0] == 5 + + +def test_c4c_to_df_kv(): + task = C4CToDF() + credentials_secret = PrefectSecret("C4C_KV").run() + res = task.run(credentials_secret=credentials_secret, endpoint="ActivityCollection") + answer = res.head() + assert answer.shape[1] == 19 diff --git a/tests/integration/tasks/test_sqlite_insert.py b/tests/integration/tasks/test_sqlite_insert.py new file mode 100644 index 000000000..379f12249 --- /dev/null +++ b/tests/integration/tasks/test_sqlite_insert.py @@ -0,0 +1,105 @@ +import sqlite3 +import os +import pytest +import pandas as pd + +from viadot.tasks.sqlite import SQLiteInsert, SQLiteSQLtoDF, SQLiteQuery + +TABLE = "test" + + +@pytest.fixture(scope="session") +def sqlite_insert(): + task = SQLiteInsert() + yield task + + +def test_sqlite_insert_proper(sqlite_insert): + dtypes = {"AA": "INT", "BB": "INT"} + df2 = pd.DataFrame({"AA": [1, 2, 3], "BB": [11, 22, 33]}) + sqlite_insert.run( + table_name=TABLE, + df=df2, + dtypes=dtypes, + if_exists="skip", + db_path="testdb.sqlite", + ) + + with sqlite3.connect("testdb.sqlite") as db: + cursor = db.cursor() + cursor.execute("""SELECT COUNT(*) from test """) + result = cursor.fetchall() + assert result[0][0] != 0 + os.remove("testdb.sqlite") + + +def test_sqlite_insert_empty(caplog, sqlite_insert): + df = pd.DataFrame() + dtypes = {"AA": "INT", "BB": "INT"} + sqlite_insert.run( + table_name=TABLE, + df=df, + dtypes=dtypes, + 
if_exists="skip", + db_path="testdb1.sqlite", + ) + + with sqlite3.connect("testdb1.sqlite") as db: + cursor2 = db.cursor() + cursor2.execute("""SELECT COUNT(*) from test """) + result2 = cursor2.fetchall() + assert result2[0][0] == 0 + + assert "DataFrame is empty" in caplog.text + os.remove("testdb1.sqlite") + + +def test_sqlite_insert_not(caplog, sqlite_insert): + dtypes = {"AA": "INT", "BB": "INT"} + not_df = [] + sqlite_insert.run( + table_name=TABLE, + df=not_df, + dtypes=dtypes, + if_exists="skip", + db_path="testdb2.sqlite", + ) + with sqlite3.connect("testdb2.sqlite") as db: + cursor3 = db.cursor() + cursor3.execute("""SELECT COUNT(*) from test """) + result3 = cursor3.fetchall() + assert result3[0][0] == 0 + + assert "Object is not a pandas DataFrame" in caplog.text + os.remove("testdb2.sqlite") + + +def test_sqlite_sql_to_df(sqlite_insert): + task = SQLiteSQLtoDF() + with sqlite3.connect("testdb3.sqlite") as db: + cursor = db.cursor() + cursor.execute("""CREATE TABLE IF NOT EXISTS test ([AA] INT, [BB] INT) """) + cursor.execute("""INSERT INTO test VALUES (11,22), (11,33)""") + + script = "SELECT * FROM test" + with open("testscript.sql", "w") as file: + file.write(script) + df = task.run(db_path="testdb3.sqlite", sql_path="testscript.sql") + assert isinstance(df, pd.DataFrame) == True + expected = pd.DataFrame({"AA": [11, 11], "BB": [22, 33]}) + assert df.equals(expected) == True + os.remove("testdb3.sqlite") + os.remove("testscript.sql") + + +def test_sqlite_to_query(sqlite_insert): + with sqlite3.connect("testdb4.sqlite") as db: + cursor = db.cursor() + cursor.execute("""CREATE TABLE IF NOT EXISTS test ([AA] INT, [BB] INT) """) + query = "INSERT INTO test VALUES (11,22), (11,33)" + task = SQLiteQuery() + task.run(query=query, db_path="testdb4.sqlite") + cursor.execute("""SELECT COUNT(*) from test """) + result = cursor.fetchall() + assert result[0][0] != 0 + os.remove("testdb4.sqlite") diff --git a/tests/integration/test_sap_rfc.py b/tests/integration/test_sap_rfc.py new file mode 100644 index 000000000..ddae30e76 --- /dev/null +++ b/tests/integration/test_sap_rfc.py @@ -0,0 +1,104 @@ +from viadot.sources import SAPRFC +from collections import OrderedDict + +sap = SAPRFC() + +sql1 = "SELECT a AS a_renamed, b FROM table1 WHERE table1.c = 1" +sql2 = "SELECT a FROM fake_schema.fake_table WHERE a=1 AND b=2 OR c LIKE 'a%' AND d IN (1, 2) LIMIT 5 OFFSET 3" +sql3 = "SELECT b FROM c WHERE testORword=1 AND testANDword=2 AND testLIMITword=3 AND testOFFSETword=4" +sql4 = "SELECT c FROM d WHERE testLIMIT = 1 AND testOFFSET = 2 AND LIMITtest=3 AND OFFSETtest=4" +sql5 = sql3 + " AND longword123=5" +sql6 = "SELECT a FROM fake_schema.fake_table WHERE a=1 AND b=2 OR c LIKE 'a%' AND d IN (1, 2) AND longcolname=3 AND otherlongcolname=5 LIMIT 5 OFFSET 3" +sql7 = """ +SELECT a, b +FROM b +WHERE c = 1 +AND d = 2 +AND longcolname = 12345 +AND otherlongcolname = 6789 +AND thirdlongcolname = 01234 +LIMIT 5 +OFFSET 10 +""" + + +def test__get_table_name(): + assert sap._get_table_name(sql1) == "table1" + assert sap._get_table_name(sql2) == "fake_schema.fake_table", sap._get_table_name( + sql2 + ) + assert sap._get_table_name(sql7) == "b" + + +def test__get_columns(): + assert sap._get_columns(sql1) == ["a", "b"] + assert sap._get_columns(sql1, aliased=True) == ["a_renamed", "b"], sap._get_columns( + sql1, aliased=True + ) + assert sap._get_columns(sql2) == ["a"] + assert sap._get_columns(sql7) == ["a", "b"] + + +def test__get_where_condition(): + assert sap._get_where_condition(sql1) == "table1.c = 1", 
sap._get_where_condition( + sql1 + ) + assert ( + sap._get_where_condition(sql2) == "a=1 AND b=2 OR c LIKE 'a%' AND d IN (1, 2)" + ), sap._get_where_condition(sql2) + assert ( + sap._get_where_condition(sql3) + == "testORword=1 AND testANDword=2 AND testLIMITword=3 AND testOFFSETword=4" + ), sap._get_where_condition(sql3) + assert ( + sap._get_where_condition(sql4) + == "testLIMIT = 1 AND testOFFSET = 2 AND LIMITtest=3 AND OFFSETtest=4" + ), sap._get_where_condition(sql4) + assert ( + sap._get_where_condition(sql7) + == "c = 1 AND d = 2 AND longcolname = 12345 AND otherlongcolname = 6789" + ), sap._get_where_condition(sql7) + + +def test__get_limit(): + assert sap._get_limit(sql1) is None + assert sap._get_limit(sql2) == 5 + assert sap._get_limit(sql7) == 5 + + +def test__get_offset(): + assert sap._get_offset(sql1) is None + assert sap._get_offset(sql2) == 3 + assert sap._get_offset(sql7) == 10 + + +def test_client_side_filters_simple(): + _ = sap._get_where_condition(sql5) + assert sap.client_side_filters == OrderedDict( + {"AND": "longword123=5"} + ), sap.client_side_filters + + +def test_client_side_filters_with_limit_offset(): + _ = sap._get_where_condition(sql6) + assert sap.client_side_filters == OrderedDict( + {"AND": "otherlongcolname=5"} + ), sap.client_side_filters + + _ = sap._get_where_condition(sql7) + assert sap.client_side_filters == OrderedDict( + {"AND": "thirdlongcolname = 01234"} + ), sap.client_side_filters + + +def test___build_pandas_filter_query(): + _ = sap._get_where_condition(sql6) + assert ( + sap._build_pandas_filter_query(sap.client_side_filters) + == "otherlongcolname == 5" + ), sap._build_pandas_filter_query(sap.client_side_filters) + _ = sap._get_where_condition(sql7) + assert ( + sap._build_pandas_filter_query(sap.client_side_filters) + == "thirdlongcolname == 01234" + ), sap._build_pandas_filter_query(sap.client_side_filters) diff --git a/tests/integration/test_sharepoint.py b/tests/integration/test_sharepoint.py index 81048f90f..d158a0768 100644 --- a/tests/integration/test_sharepoint.py +++ b/tests/integration/test_sharepoint.py @@ -1,12 +1,23 @@ import pytest import os import pathlib +import json import pandas as pd +import configparser from viadot.exceptions import CredentialError from viadot.sources import Sharepoint from viadot.config import local_config from viadot.task_utils import df_get_data_types_task +from viadot.tasks.sharepoint import SharepointToDF + +from prefect.tasks.secrets import PrefectSecret + + +def get_url(): + with open(".config/credentials.json", "r") as f: + config = json.load(f) + return config["SHAREPOINT"]["url"] @pytest.fixture(scope="session") @@ -17,8 +28,8 @@ def sharepoint(): @pytest.fixture(scope="session") def FILE_NAME(sharepoint): - path = "EUL Data.xlsm" - sharepoint.download_file(download_to_path=path) + path = "Questionnaires.xlsx" + sharepoint.download_file(download_to_path=path, download_from_path=get_url()) yield path os.remove(path) @@ -38,6 +49,19 @@ def test_connection(sharepoint): assert response.status_code == 200 +def test_sharepoint_to_df_task(): + task = SharepointToDF() + credentials_secret = PrefectSecret("SHAREPOINT_KV").run() + res = task.run( + credentials_secret=credentials_secret, + sheet_number=0, + path_to_file="Questionnaires.xlsx", + url_to_file=get_url(), + ) + assert isinstance(res, pd.DataFrame) + os.remove("Questionnaires.xlsx") + + def test_file_download(FILE_NAME): files = [] for file in os.listdir(): @@ -47,12 +71,12 @@ def test_file_download(FILE_NAME): def 
test_autopopulating_download_from(FILE_NAME): - assert os.path.basename(sharepoint.download_from_path) == FILE_NAME + assert os.path.basename(get_url()) == FILE_NAME def test_file_extension(sharepoint): - file_ext = [".xlsm", ".xlsx"] - assert pathlib.Path(sharepoint.download_from_path).suffix in file_ext + file_ext = (".xlsm", ".xlsx") + assert get_url().endswith(file_ext) def test_file_to_df(FILE_NAME): diff --git a/tests/integration/test_uk_carbon_intensity.py b/tests/integration/test_uk_carbon_intensity.py index 5ff8ffced..c6dc7e72f 100644 --- a/tests/integration/test_uk_carbon_intensity.py +++ b/tests/integration/test_uk_carbon_intensity.py @@ -53,3 +53,9 @@ def test_stats_to_csv(carbon): carbon.query(f"/intensity/stats/{from_.isoformat()}/{to.isoformat()}") carbon.to_csv(TEST_FILE_2, if_exists="append") assert os.path.isfile(TEST_FILE_2) == True + + +def test_to_df_today(carbon): + carbon.query("/intensity/date") + df = carbon.to_df() + assert len(df) > 1 diff --git a/tests/test_viadot.py b/tests/test_viadot.py index bdf505e78..67195838c 100644 --- a/tests/test_viadot.py +++ b/tests/test_viadot.py @@ -2,4 +2,4 @@ def test_version(): - assert __version__ == "0.2.15" + assert __version__ == "0.3.0" diff --git a/tests/unit/test_base.py b/tests/unit/test_base.py index 71bdd631b..a658082da 100644 --- a/tests/unit/test_base.py +++ b/tests/unit/test_base.py @@ -1,7 +1,9 @@ import os - +import pytest import pandas as pd +import pyarrow as pa from viadot.sources.base import SQL, Source +from viadot.signals import SKIP from .test_credentials import get_credentials @@ -10,6 +12,14 @@ PATH = "t.csv" +class NotEmptySource(Source): + def to_df(self, if_empty): + df = pd.DataFrame.from_dict( + data={"country": ["italy", "germany", "spain"], "sales": [100, 50, 80]} + ) + return df + + class EmptySource(Source): def to_df(self, if_empty): df = pd.DataFrame() @@ -24,6 +34,38 @@ def test_empty_source_skip(): assert result is False +def test_to_csv(): + src = NotEmptySource() + res = src.to_csv(path="testbase.csv") + assert res == True + assert os.path.isfile("testbase.csv") == True + os.remove("testbase.csv") + + +def test_to_arrow(): + src = NotEmptySource() + res = src.to_arrow("testbase.arrow") + assert isinstance(res, pa.Table) == True + + +def test_to_excel(): + src = NotEmptySource() + res = src.to_excel(path="testbase.xlsx") + assert res == True + assert os.path.isfile("testbase.xlsx") == True + os.remove("testbase.xlsx") + + +def test_handle_if_empty(caplog): + src = EmptySource() + src._handle_if_empty(if_empty="warn") + assert "WARNING The query produced no data." 
in caplog.text + with pytest.raises(ValueError): + src._handle_if_empty(if_empty="fail") + with pytest.raises(SKIP): + src._handle_if_empty(if_empty="skip") + + # def test_to_csv_append(): # """Test whether `to_csv()` with the append option writes data of correct shape""" # driver = "/usr/lib/x86_64-linux-gnu/odbc/libsqlite3odbc.so" diff --git a/tests/unit/test_duckdb.py b/tests/unit/test_duckdb.py new file mode 100644 index 000000000..05b1778da --- /dev/null +++ b/tests/unit/test_duckdb.py @@ -0,0 +1,63 @@ +import os + +import pytest +from viadot.sources.duckdb import DuckDB +import os + +TABLE = "test_table" +SCHEMA = "test_schema" +TABLE_MULTIPLE_PARQUETS = "test_multiple_parquets" +DATABASE_PATH = "test.duckdb" + + +@pytest.fixture(scope="session") +def duckdb(): + try: + os.remove(DATABASE_PATH) + except FileNotFoundError: + pass + duckdb = DuckDB(credentials=dict(database=DATABASE_PATH)) + yield duckdb + os.remove(DATABASE_PATH) + + +def test__check_if_schema_exists(duckdb): + + duckdb.run(f"DROP SCHEMA IF EXISTS {SCHEMA}") + assert not duckdb._check_if_schema_exists(SCHEMA) + + duckdb.run(f"CREATE SCHEMA {SCHEMA}") + assert not duckdb._check_if_schema_exists(SCHEMA) + + duckdb.run(f"DROP SCHEMA {SCHEMA}") + + +def test_create_table_from_parquet(duckdb, TEST_PARQUET_FILE_PATH): + duckdb.create_table_from_parquet( + schema=SCHEMA, table=TABLE, path=TEST_PARQUET_FILE_PATH + ) + df = duckdb.to_df(f"SELECT * FROM {SCHEMA}.{TABLE}") + assert df.shape[0] == 3 + duckdb.drop_table(TABLE, schema=SCHEMA) + duckdb.run(f"DROP SCHEMA {SCHEMA}") + + +def test_create_table_from_multiple_parquet(duckdb): + # we use the two Parquet files generated by fixtures in conftest + duckdb.create_table_from_parquet( + schema=SCHEMA, table=TABLE_MULTIPLE_PARQUETS, path="*.parquet" + ) + df = duckdb.to_df(f"SELECT * FROM {SCHEMA}.{TABLE_MULTIPLE_PARQUETS}") + assert df.shape[0] == 6 + duckdb.drop_table(TABLE_MULTIPLE_PARQUETS, schema=SCHEMA) + duckdb.run(f"DROP SCHEMA {SCHEMA}") + + +def test__check_if_table_exists(duckdb, TEST_PARQUET_FILE_PATH): + + assert not duckdb._check_if_table_exists(table=TABLE, schema=SCHEMA) + duckdb.create_table_from_parquet( + schema=SCHEMA, table=TABLE, path=TEST_PARQUET_FILE_PATH + ) + assert duckdb._check_if_table_exists(TABLE, schema=SCHEMA) + duckdb.drop_table(TABLE, schema=SCHEMA) diff --git a/tests/unit/test_sqlite.py b/tests/unit/test_sqlite.py index 89bbf7c8e..5b0831273 100644 --- a/tests/unit/test_sqlite.py +++ b/tests/unit/test_sqlite.py @@ -40,3 +40,10 @@ def test_insert_into_sql(sqlite, DF): results = sqlite.run(f"SELECT * FROM {TABLE}") df = pandas.DataFrame.from_records(results, columns=["country", "sales"]) assert df["sales"].sum() == 230 + + +def test_check_if_table_exists(sqlite): + exists = sqlite._check_if_table_exists(TABLE) + assert exists == True + not_exists = sqlite._check_if_table_exists("test_table") + assert not_exists == False diff --git a/tests/unit/test_task_utils.py b/tests/unit/test_task_utils.py index f4bc977ee..aea910957 100644 --- a/tests/unit/test_task_utils.py +++ b/tests/unit/test_task_utils.py @@ -1,8 +1,19 @@ import pytest +import numpy as np +import os import pandas as pd from typing import List -from viadot.task_utils import df_get_data_types_task, df_map_mixed_dtypes_for_parquet +from viadot.task_utils import ( + chunk_df, + df_get_data_types_task, + df_map_mixed_dtypes_for_parquet, + df_to_csv, + df_to_parquet, + union_dfs_task, + dtypes_to_json, + write_to_json, +) def count_dtypes(dtypes_dict: dict = None, dtypes_to_count: List[str] = 
None) -> int: @@ -32,3 +43,104 @@ def test_map_dtypes_for_parquet(): sum_of_mapped_dtypes = count_dtypes(dtyps_dict_mapped, ["String"]) assert sum_of_dtypes == sum_of_mapped_dtypes + + +def test_chunk_df(): + df = pd.DataFrame( + { + "AA": [1, 2, 3, 4, 5], + "BB": [11, 22, 33, 4, 5], + "CC": [4, 5, 6, 4, 5], + "DD": [44, 55, 66, 1, 2], + } + ) + res = chunk_df.run(df=df, size=2) + assert len(res) == 3 + + +def test_df_get_data_types_task(): + df = pd.DataFrame( + { + "a": {0: "ann", 1: "test", 2: "Hello"}, + "b": {0: 9, 1: "2021-01-01", 2: "Hello"}, + "w": {0: 679, 1: "Hello", 2: "Hello"}, + "x": {0: -1, 1: 2, 2: 444}, + "y": {0: 1.5, 1: 11.97, 2: 56.999}, + "z": {0: "2022-01-01", 1: "2021-11-01", 2: "2021-01-01"}, + } + ) + res = df_get_data_types_task.run(df) + assert res == { + "a": "String", + "b": "Object", + "w": "Object", + "x": "Integer", + "y": "Float", + "z": "Date", + } + + +def test_df_to_csv(): + df = pd.DataFrame( + { + "a": {0: "a", 1: "b", 2: "c"}, + "b": {0: "a", 1: "b", 2: "c"}, + "w": {0: "a", 1: "b", 2: "c"}, + } + ) + + df_to_csv.run(df, "test.csv") + result = pd.read_csv("test.csv", sep="\t") + assert df.equals(result) + os.remove("test.csv") + + +def test_df_to_parquet(): + df = pd.DataFrame( + { + "a": {0: "a", 1: "b", 2: "c"}, + "b": {0: "a", 1: "b", 2: "c"}, + "w": {0: "a", 1: "b", 2: "c"}, + } + ) + + df_to_parquet.run(df, "test.parquet") + result = pd.read_parquet("test.parquet") + assert df.equals(result) + os.remove("test.parquet") + + +def test_union_dfs_task(): + df1 = pd.DataFrame( + { + "a": {0: "a", 1: "b", 2: "c"}, + "b": {0: "a", 1: "b", 2: "c"}, + "w": {0: "a", 1: "b", 2: "c"}, + } + ) + df2 = pd.DataFrame( + { + "a": {0: "d", 1: "e"}, + "b": {0: "d", 1: "e"}, + } + ) + list_dfs = [] + list_dfs.append(df1) + list_dfs.append(df2) + res = union_dfs_task.run(list_dfs) + assert isinstance(res, pd.DataFrame) + assert len(res) == 5 + + +def test_dtypes_to_json(): + dtypes = {"country": "VARCHAR(100)", "sales": "FLOAT(24)"} + dtypes_to_json.run(dtypes_dict=dtypes, local_json_path="dtypes.json") + assert os.path.exists("dtypes.json") + os.remove("dtypes.json") + + +def test_write_to_json(): + dict = {"name": "John", 1: [2, 4, 3]} + write_to_json.run(dict, "dict.json") + assert os.path.exists("dict.json") + os.remove("dict.json") diff --git a/viadot/__init__.py b/viadot/__init__.py index ddc77a880..493f7415d 100644 --- a/viadot/__init__.py +++ b/viadot/__init__.py @@ -1 +1 @@ -__version__ = "0.2.15" +__version__ = "0.3.0" diff --git a/viadot/examples/hello_world.py b/viadot/examples/hello_world.py index 3fb6c5a7e..ed6b7af71 100644 --- a/viadot/examples/hello_world.py +++ b/viadot/examples/hello_world.py @@ -41,7 +41,7 @@ def say_bye(): git_token_secret_name="github_token", # name of the Prefect secret with the GitHub token ) RUN_CONFIG = DockerRun( - image="prefecthq/prefect", + image="prefecthq/prefect:0.15.11-python3.8", env={"SOME_VAR": "value"}, labels=["dev"], ) diff --git a/viadot/examples/sap_rfc/Dockerfile b/viadot/examples/sap_rfc/Dockerfile new file mode 100644 index 000000000..2832a3ecd --- /dev/null +++ b/viadot/examples/sap_rfc/Dockerfile @@ -0,0 +1,15 @@ +FROM viadot:latest + +USER root + +COPY sap_netweaver_rfc/nwrfcsdk /usr/local/sap/nwrfcsdk +COPY sap_netweaver_rfc/nwrfcsdk.conf /etc/ld.so.conf.d/nwrfcsdk.conf + +ENV SAPNWRFC_HOME=/usr/local/sap/nwrfcsdk + +RUN ldconfig + +COPY requirements.txt . 
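+# Install the requirements one line at a time (in file order) so that cython is
+# already present when pyrfc is built (see the note in docker/Dockerfile above).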
+RUN xargs -L 1 pip install < requirements.txt + +USER viadot \ No newline at end of file diff --git a/viadot/examples/sap_rfc/README.md b/viadot/examples/sap_rfc/README.md new file mode 100644 index 000000000..46e584e04 --- /dev/null +++ b/viadot/examples/sap_rfc/README.md @@ -0,0 +1,11 @@ +## SAP RFC example + +This is an example environment for running the `SAPRFC` connector. + + +Note that we refer to a `sap_netweaver_rfc` folder in the Dockerfile. This is the folder containing the proprietary SAP NetWeaver driver that would have to be obtained and installed by the user. + +### Running SAPRFC +To build the image, run `docker build . -t viadot:sap_rfc`, and spin it up with the provided `docker-compose`: `docker-compose up -d`. You can now open up Jupyter Lab at `localhost:5678`. + +To run tests, run eg. `docker exec -it viadot_saprfc_lab pytest tests/integration/test_sap_rfc.py`. \ No newline at end of file diff --git a/viadot/examples/sap_rfc/docker-compose.yml b/viadot/examples/sap_rfc/docker-compose.yml new file mode 100644 index 000000000..79966472c --- /dev/null +++ b/viadot/examples/sap_rfc/docker-compose.yml @@ -0,0 +1,12 @@ +version: "3" + +services: + viadot_saprfc_lab: + image: viadot:sap_rfc + container_name: viadot_saprfc_lab + ports: + - 5678:8888 + volumes: + - ../../../:/home/viadot + command: jupyter lab --no-browser --ip 0.0.0.0 --LabApp.token='' + restart: "unless-stopped" \ No newline at end of file diff --git a/viadot/examples/sap_rfc/requirements.txt b/viadot/examples/sap_rfc/requirements.txt new file mode 100644 index 000000000..3e0342e61 --- /dev/null +++ b/viadot/examples/sap_rfc/requirements.txt @@ -0,0 +1,3 @@ +cython==0.29.24 +pyrfc==2.5.0 +sql-metadata==2.3.0 \ No newline at end of file diff --git a/viadot/flows/__init__.py b/viadot/flows/__init__.py index 122016a08..696c7aeec 100644 --- a/viadot/flows/__init__.py +++ b/viadot/flows/__init__.py @@ -9,3 +9,8 @@ from .adls_container_to_container import ADLSContainerToContainer from .sharepoint_to_adls import SharepointToADLS from .cloud_for_customers_report_to_adls import CloudForCustomersReportToADLS + +try: + from .sap_to_duckdb import SAPToDuckDB +except ImportError: + pass diff --git a/viadot/flows/adls_to_azure_sql.py b/viadot/flows/adls_to_azure_sql.py index f01a6c6ab..37b74cd70 100644 --- a/viadot/flows/adls_to_azure_sql.py +++ b/viadot/flows/adls_to_azure_sql.py @@ -17,6 +17,8 @@ AzureSQLCreateTable, BCPTask, DownloadGitHubFile, + AzureSQLDBQuery, + CheckColumnOrder, ) logger = logging.get_logger(__name__) @@ -24,10 +26,12 @@ lake_to_df_task = AzureDataLakeToDF() download_json_file_task = AzureDataLakeDownload() download_github_file_task = DownloadGitHubFile() -promote_to_conformed_task = AzureDataLakeUpload() +promote_to_conformed_task = AzureDataLakeCopy() promote_to_operations_task = AzureDataLakeCopy() create_table_task = AzureSQLCreateTable() bulk_insert_task = BCPTask() +azure_query_task = AzureSQLDBQuery() +check_column_order_task = CheckColumnOrder() @task @@ -71,6 +75,10 @@ def map_data_types_task(json_shema_path: str): @task def df_to_csv_task(df, path: str, sep: str = "\t"): + for col in range(len(df.columns)): + df[df.columns[col]] = ( + df[df.columns[col]].astype(str).str.replace(r"\t", "", regex=True) + ) df.to_csv(path, sep=sep, index=False) @@ -215,14 +223,24 @@ def gen_flow(self) -> Flow: else: dtypes = self.dtypes + df_reorder = check_column_order_task.bind( + table=self.table, + df=df, + if_exists=self.if_exists, + credentials_secret=self.sqldb_credentials_secret, + flow=self, + ) + 
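+        # Write the CSV from the reordered DataFrame so its column order matches the target table.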
df_to_csv = df_to_csv_task.bind( - df=df, path=self.local_file_path, sep=self.write_sep, flow=self + df=df_reorder, + path=self.local_file_path, + sep=self.write_sep, + flow=self, ) promote_to_conformed_task.bind( from_path=self.local_file_path, to_path=self.adls_path_conformed, - overwrite=self.overwrite_adls, sp_credentials_secret=self.adls_sp_credentials_secret, vault_name=self.vault_name, flow=self, @@ -252,10 +270,10 @@ def gen_flow(self) -> Flow: flow=self, ) - # dtypes.set_upstream(download_json_file_task, flow=self) + df_reorder.set_upstream(lake_to_df_task, flow=self) + df_to_csv.set_upstream(df_reorder, flow=self) promote_to_conformed_task.set_upstream(df_to_csv, flow=self) promote_to_conformed_task.set_upstream(df_to_csv, flow=self) - # map_data_types_task.set_upstream(download_json_file_task, flow=self) create_table_task.set_upstream(df_to_csv, flow=self) promote_to_operations_task.set_upstream(promote_to_conformed_task, flow=self) bulk_insert_task.set_upstream(create_table_task, flow=self) diff --git a/viadot/flows/sap_to_duckdb.py b/viadot/flows/sap_to_duckdb.py new file mode 100644 index 000000000..e50e1817d --- /dev/null +++ b/viadot/flows/sap_to_duckdb.py @@ -0,0 +1,99 @@ +from typing import Any, Dict, List, Literal +from prefect import Flow +from prefect.utilities import logging + +logger = logging.get_logger() + +from ..task_utils import ( + add_ingestion_metadata_task, + df_to_parquet, +) +from ..tasks import SAPRFCToDF, DuckDBCreateTableFromParquet + + +class SAPToDuckDB(Flow): + def __init__( + self, + query: str, + table: str, + local_file_path: str, + name: str = None, + sep: str = "\t", + autopick_sep: bool = True, + schema: str = None, + table_if_exists: Literal["fail", "replace", "skip", "delete"] = "fail", + sap_credentials: dict = None, + duckdb_credentials: dict = None, + *args: List[any], + **kwargs: Dict[str, Any], + ): + """A flow for moving data from SAP to DuckDB. + + Args: + query (str): The query to be executed on SAP with pyRFC. + table (str): Destination table in DuckDB. + local_file_path (str): The path to the source Parquet file. + name (str, optional): The name of the flow. Defaults to None. + sep (str, optional): The separator to use when reading query results. Defaults to "\t". + autopick_sep (bool, optional): Whether SAPRFC should try different separators + in case the query fails with the default one. Defaults to True. + schema (str, optional): Destination schema in DuckDB. Defaults to None. + table_if_exists (Literal, optional): What to do if the table already exists. Defaults to "fail". + sap_credentials (dict, optional): The credentials to use to authenticate with SAP. + By default, they're taken from the local viadot config. + duckdb_credentials (dict, optional): The config to use for connecting with DuckDB. Defaults to None. 
+ """ + + # SAPRFCToDF + self.query = query + self.sep = sep + self.autopick_sep = autopick_sep + self.sap_credentials = sap_credentials + + # DuckDBCreateTableFromParquet + self.table = table + self.schema = schema + self.if_exists = table_if_exists + self.local_file_path = local_file_path or self.slugify(name) + ".parquet" + self.duckdb_credentials = duckdb_credentials + + super().__init__(*args, name=name, **kwargs) + + self.sap_to_df_task = SAPRFCToDF(credentials=sap_credentials) + self.create_duckdb_table_task = DuckDBCreateTableFromParquet( + credentials=duckdb_credentials + ) + + self.gen_flow() + + def gen_flow(self) -> Flow: + + df = self.sap_to_df_task.bind( + query=self.query, + sep=self.sep, + autopick_sep=self.autopick_sep, + flow=self, + ) + + df_with_metadata = add_ingestion_metadata_task.bind(df, flow=self) + + parquet = df_to_parquet.bind( + df=df_with_metadata, + path=self.local_file_path, + if_exists=self.if_exists, + flow=self, + ) + + table = self.create_duckdb_table_task.bind( + path=self.local_file_path, + schema=self.schema, + table=self.table, + if_exists=self.if_exists, + flow=self, + ) + + table.set_upstream(parquet, flow=self) + + @staticmethod + def slugify(name): + return name.replace(" ", "_").lower() diff --git a/viadot/sources/__init__.py b/viadot/sources/__init__.py index 177d13485..a4d9d1204 100644 --- a/viadot/sources/__init__.py +++ b/viadot/sources/__init__.py @@ -5,6 +5,12 @@ from .cloud_for_customers import CloudForCustomers from .sharepoint import Sharepoint +try: + from .sap_rfc import SAPRFC +except ImportError: + pass + # APIS from .uk_carbon_intensity import UKCarbonIntensity from .sqlite import SQLite +from .duckdb import DuckDB diff --git a/viadot/sources/azure_data_lake.py b/viadot/sources/azure_data_lake.py index d55a003cc..804a448b5 100644 --- a/viadot/sources/azure_data_lake.py +++ b/viadot/sources/azure_data_lake.py @@ -174,14 +174,32 @@ def to_df( return df def ls(self, path: str = None) -> List[str]: + """Returns list of files in a path. + + Args: + path (str, optional): Path to a folder. Defaults to None. + """ path = path or self.path return self.fs.ls(path) def rm(self, path: str = None, recursive: bool = False): + """Deletes files in a path. + + Args: + path (str, optional): Path to a folder. Defaults to None. + recursive (bool, optional): Whether to delete files recursively or not. Defaults to False. + """ path = path or self.path self.fs.rm(path, recursive=recursive) def cp(self, from_path: str = None, to_path: str = None, recursive: bool = False): + """Copies source to a destination. + + Args: + from_path (str, optional): Path form which to copy file. Defauls to None. + to_path (str, optional): Path where to copy files. Defaults to None. + recursive (bool, optional): Whether to copy files recursively or not. Defaults to False. 
+ """ from_path = from_path or self.path to_path = to_path self.fs.cp(from_path, to_path, recursive=recursive) diff --git a/viadot/sources/azure_sql.py b/viadot/sources/azure_sql.py index afaefbae3..64a98ea99 100644 --- a/viadot/sources/azure_sql.py +++ b/viadot/sources/azure_sql.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Literal +from typing import List, Literal from prefect.utilities import logging @@ -8,6 +8,8 @@ class AzureSQL(SQL): + DEFAULT_SCHEMA = "dbo" + def __init__( self, *args, @@ -18,11 +20,13 @@ def __init__( @property def schemas(self) -> List[str]: + """Returns list of schemas""" schemas_tuples = self.run("SELECT s.name as schema_name from sys.schemas s") return [schema_tuple[0] for schema_tuple in schemas_tuples] @property def tables(self) -> List[str]: + """Returns list of tables""" tables_tuples = self.run("SELECT * FROM information_schema.tables") return [table for row in tables_tuples for table in row] @@ -34,6 +38,14 @@ def bulk_insert( sep="\t", if_exists: Literal = "append", ): + """Fuction to bulk insert. + Args: + table (str): Table name. + schema (str, optional): Schema name. Defaults to None. + source_path (str, optional): Full path to a data file. Defaults to one. + sep (str, optional): field terminator to be used for char and widechar data files. Defaults to "\t". + if_exists (Literal, optional): What to do if the table already exists. Defaults to "append". + """ if schema is None: schema = self.DEFAULT_SCHEMA fqn = f"{schema}.{table}" @@ -121,7 +133,7 @@ def exists(self, table: str, schema: str = None) -> bool: """ if not schema: - schema = "dbo" + schema = self.DEFAULT_SCHEMA list_table_info_query = f""" SELECT * diff --git a/viadot/sources/base.py b/viadot/sources/base.py index 9b9c5a99a..6dc39160f 100644 --- a/viadot/sources/base.py +++ b/viadot/sources/base.py @@ -1,6 +1,6 @@ import os from abc import abstractmethod -from typing import Any, Dict, List, Literal, NoReturn, Tuple +from typing import Any, Dict, List, Literal, NoReturn, Tuple, Union import pandas as pd import pyarrow as pa @@ -34,6 +34,11 @@ def query(): pass def to_arrow(self, if_empty: str = "warn") -> pa.Table: + """ + Creates a pyarrow table from source. + Args: + if_empty (str, optional): : What to do if data sourse contains no data. Defaults to "warn". + """ try: df = self.to_df(if_empty=if_empty) @@ -94,6 +99,14 @@ def to_csv( def to_excel( self, path: str, if_exists: str = "replace", if_empty: str = "warn" ) -> bool: + """ + Write from source to a excel file. + Args: + path (str): The destination path. + if_exists (str, optional): What to do if the file exists. Defaults to "replace". + if_empty (str, optional): What to do if the source contains no data. + + """ try: df = self.to_df(if_empty=if_empty) @@ -112,6 +125,7 @@ def to_excel( return True def _handle_if_empty(self, if_empty: str = None) -> NoReturn: + """What to do if empty.""" if if_empty == "warn": logger.warning("The query produced no data.") elif if_empty == "skip": @@ -188,7 +202,7 @@ def con(self) -> pyodbc.Connection: self._con.timeout = self.query_timeout return self._con - def run(self, query: str) -> List[Record]: + def run(self, query: str) -> Union[List[Record], bool]: cursor = self.con.cursor() cursor.execute(query) @@ -203,6 +217,11 @@ def run(self, query: str) -> List[Record]: return result def to_df(self, query: str, if_empty: str = None) -> pd.DataFrame: + """Creates DataFrame form SQL query. + Args: + query (str): SQL query. If don't start with "SELECT" returns empty DataFrame. 
+ if_empty (str, optional): What to do if the query returns no data. Defaults to None. + """ conn = self.con if query.upper().startswith("SELECT"): df = pd.read_sql_query(query, conn) @@ -213,6 +232,11 @@ def to_df(self, query: str, if_empty: str = None) -> pd.DataFrame: return df def _check_if_table_exists(self, table: str, schema: str = None) -> bool: + """Checks if table exists. + Args: + table (str): Table name. + schema (str, optional): Schema name. Defaults to None. + """ exists_query = f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = '{schema}' AND TABLE_NAME='{table}'" exists = bool(self.run(exists_query)) return exists @@ -292,6 +316,7 @@ def insert_into(self, table: str, df: pd.DataFrame) -> str: return sql def _sql_column(self, column_name: str) -> str: + """Returns the name of a column""" if isinstance(column_name, str): out_name = f"'{column_name}'" else: diff --git a/viadot/sources/cloud_for_customers.py b/viadot/sources/cloud_for_customers.py index 1278131e5..39759b6dc 100644 --- a/viadot/sources/cloud_for_customers.py +++ b/viadot/sources/cloud_for_customers.py @@ -7,6 +7,7 @@ from ..utils import handle_api_response from ..exceptions import CredentialError import re +from copy import deepcopy class CloudForCustomers(Source): @@ -21,9 +22,12 @@ def __init__( credentials: Dict[str, Any] = None, **kwargs, ): - """ - Fetches data from Cloud for Customer. - Args: + """Cloud for Customers connector build for fetching Odata source. + See [pyodata docs](https://pyodata.readthedocs.io/en/latest/index.html) for an explanation + how Odata works. + + Parameters + ---------- report_url (str, optional): The url to the API in case of prepared report. Defaults to None. url (str, optional): The url to the API. Defaults to None. endpoint (str, optional): The endpoint of the API. Defaults to None. @@ -65,6 +69,9 @@ def change_to_meta_url(url: str) -> str: return meta_url def _to_records_report(self, url: str) -> List[Dict[str, Any]]: + """Fetches the data from source with report_url. + At first enter url is from function parameter. At next is generated automaticaly. + """ records = [] while url: response = self.get_response(url) @@ -77,37 +84,34 @@ def _to_records_report(self, url: str) -> List[Dict[str, Any]]: return records def _to_records_other(self, url: str) -> List[Dict[str, Any]]: + """Fetches the data from source with url. + At first enter url is a join of url and endpoint passed into this function. + At any other entering it bring `__next_url` adress, generated automatically, but without params. 
+ """ records = [] - tmp_full_url = self.full_url - tmp_params = self.params + tmp_full_url = deepcopy(url) + tmp_params = deepcopy(self.params) while url: - response = self.get_response(tmp_full_url) + response = self.get_response(tmp_full_url, params=tmp_params) response_json = response.json() if isinstance(response_json["d"], dict): # ODATA v2+ API new_records = response_json["d"].get("results") - url = None - self.params = None - self.endpoint = None - url = response_json["d"].get("__next") - tmp_full_url = url - + url = response_json["d"].get("__next", None) else: # ODATA v1 new_records = response_json["d"] - url = None - self.params = None - self.endpoint = None - url = response_json.get("__next") - tmp_full_url = url - + url = response_json.get("__next", None) + # prevents concatenation of previous url's with params with the same params + tmp_params = None + tmp_full_url = url records.extend(new_records) - self.params = tmp_params - return records def to_records(self) -> List[Dict[str, Any]]: - """Download a list of entities in the records format""" + """ + Download a list of entities in the records format + """ if self.is_report: url = self.report_url return self._to_records_report(url=url) @@ -116,6 +120,17 @@ def to_records(self) -> List[Dict[str, Any]]: return self._to_records_other(url=url) def response_to_entity_list(self, dirty_json: Dict[str, Any], url: str) -> List: + + """Changing request json response to list. + + Args: + dirty_json (Dict[str, Any]): json from response. + url (str): the URL which trying to fetch metadata. + + Returns: + List: List of dictionaries. + """ + metadata_url = self.change_to_meta_url(url) column_maper_dict = self.map_columns(metadata_url) entity_list = [] @@ -133,11 +148,20 @@ def response_to_entity_list(self, dirty_json: Dict[str, Any], url: str) -> List: return entity_list def map_columns(self, url: str = None) -> Dict[str, str]: + + """Fetch metadata from url used to column name map. + + Args: + url (str, optional): the URL which trying to fetch metadata. Defaults to None. + + Returns: + Dict[str, str]: Property Name as key mapped to the value of sap label. + """ column_mapping = {} if url: username = self.credentials.get("username") pw = self.credentials.get("password") - response = requests.get(url, params=self.params, auth=(username, pw)) + response = requests.get(url, auth=(username, pw)) for sentence in response.text.split("/>"): result = re.search( r'(?<=Name=")([^"]+).+(sap:label=")([^"]+)+', sentence @@ -148,15 +172,37 @@ def map_columns(self, url: str = None) -> Dict[str, str]: column_mapping[key] = val return column_mapping - def get_response(self, url: str, timeout: tuple = (3.05, 60 * 30)) -> pd.DataFrame: + def get_response( + self, url: str, params: Dict[str, Any] = None, timeout: tuple = (3.05, 60 * 30) + ) -> requests.models.Response: + """Handle and raise Python exceptions during request. Using of url and service endpoint needs additional parameters + stores in params. report_url contain additional params in their structure. + In report_url scenario it can not contain params parameter. + + Args: + url (str): the URL which trying to connect. + params (Dict[str, Any], optional): Additional parameters like filter, used in case of normal url. + Defaults to None used in case of report_url, which can not contain params. + timeout (tuple, optional): the request times out. Defaults to (3.05, 60 * 30). 
+ + Returns: + requests.models.Response + """ username = self.credentials.get("username") pw = self.credentials.get("password") response = handle_api_response( - url=url, params=self.params, auth=(username, pw), timeout=timeout + url=url, + params=params, + auth=(username, pw), + timeout=timeout, ) return response def to_df(self, fields: List[str] = None, if_empty: str = "warn") -> pd.DataFrame: + """Returns records in a pandas DataFrame. + Args: + fields (List[str], optional): List of fields to put in DataFrame. Defaults to None. + """ records = self.to_records() df = pd.DataFrame(data=records) if fields: diff --git a/viadot/sources/duckdb.py b/viadot/sources/duckdb.py new file mode 100644 index 000000000..3d7a00fdc --- /dev/null +++ b/viadot/sources/duckdb.py @@ -0,0 +1,213 @@ +from multiprocessing.sharedctypes import Value +from typing import Any, List, Literal, NoReturn, Tuple, Union + +import pandas as pd +from prefect.utilities import logging + +import duckdb + +from ..config import local_config +from ..exceptions import CredentialError +from ..signals import SKIP +from .base import Source + +logger = logging.get_logger(__name__) + +Record = Tuple[Any] + + +class DuckDB(Source): + DEFAULT_SCHEMA = "main" + + def __init__( + self, + config_key: str = "DuckDB", + credentials: dict = None, + *args, + **kwargs, + ): + """A class for interacting with DuckDB. + + Args: + config_key (str, optional): The key inside local config containing the config. + User can choose to use this or pass credentials directly to the `credentials` + parameter. Defaults to None. + credentials (dict, optional): Credentials for the connection. Defaults to None. + """ + + if config_key: + config_credentials = local_config.get(config_key) + + credentials = credentials if credentials else config_credentials + if credentials is None: + raise CredentialError("Credentials not found.") + + super().__init__(*args, credentials=credentials, **kwargs) + + @property + def con(self) -> duckdb.DuckDBPyConnection: + """Return a new connection to the database. As the views are highly isolated, + we need a new connection for each query in order to see the changes from + previous queries (eg. if we create a new table and then we want to list + tables from INFORMATION_SCHEMA, we need to create a new DuckDB connection). + + Returns: + duckdb.DuckDBPyConnection: database connection. + """ + return duckdb.connect( + database=self.credentials.get("database"), + read_only=self.credentials.get("read_only", False), + ) + + @property + def tables(self) -> List[str]: + """Show the list of fully qualified table names. + + Returns: + List[str]: The list of tables in the format '{SCHEMA}.{TABLE}'. + """ + tables_meta: List[Tuple] = self.run("SELECT * FROM information_schema.tables") + tables = [table_meta[1] + "." + table_meta[2] for table_meta in tables_meta] + return tables + + def to_df(self, query: str, if_empty: str = None) -> pd.DataFrame: + if query.upper().startswith("SELECT"): + df = self.run(query, fetch_type="dataframe") + if df.empty: + self._handle_if_empty(if_empty=if_empty) + else: + df = pd.DataFrame() + return df + + def run( + self, query: str, fetch_type: Literal["record", "dataframe"] = "record" + ) -> Union[List[Record], bool]: + """Run a query on DuckDB. + + Args: + query (str): The query to execute. + fetch_type (Literal[, optional): How to return the data: either + in the default record format or as a pandas DataFrame. Defaults to "record". 
+ + Returns: + Union[List[Record], bool]: Either the result set of a query or, + in case of DDL/DML queries, a boolean describing whether + the query was excuted successfuly. + """ + allowed_fetch_type_values = ["record", "dataframe"] + if fetch_type not in allowed_fetch_type_values: + raise ValueError( + f"Only the values {allowed_fetch_type_values} are allowed for 'fetch_type'" + ) + cursor = self.con.cursor() + cursor.execute(query) + + query_clean = query.upper().strip() + query_keywords = ["SELECT", "SHOW", "PRAGMA"] + if any(query_clean.startswith(word) for word in query_keywords): + if fetch_type == "record": + result = cursor.fetchall() + else: + result = cursor.fetchdf() + else: + result = True + + cursor.close() + return result + + def _handle_if_empty(self, if_empty: str = None) -> NoReturn: + if if_empty == "warn": + logger.warning("The query produced no data.") + elif if_empty == "skip": + raise SKIP("The query produced no data. Skipping...") + elif if_empty == "fail": + raise ValueError("The query produced no data.") + + def create_table_from_parquet( + self, + table: str, + path: str, + schema: str = None, + if_exists: Literal["fail", "replace", "skip", "delete"] = "fail", + ) -> NoReturn: + """Create a DuckDB table with a CTAS from Parquet file(s). + + Args: + table (str): Destination table. + path (str): The path to the source Parquet file(s). Glob expressions are + also allowed here (eg. `my_folder/*.parquet`). + schema (str, optional): Destination schema. Defaults to None. + if_exists (Literal[, optional): What to do if the table already exists. Defaults to "fail". + + Raises: + ValueError: If the table exists and `if_exists` is set to `fail`. + + Returns: + NoReturn: Does not return anything. + """ + schema = schema or DuckDB.DEFAULT_SCHEMA + fqn = schema + "." + table + exists = self._check_if_table_exists(schema=schema, table=table) + + if exists: + if if_exists == "replace": + self.run(f"DROP TABLE {fqn}") + elif if_exists == "delete": + self.run(f"DELETE FROM {fqn}") + return True + elif if_exists == "fail": + raise ValueError( + "The table already exists and 'if_exists' is set to 'fail'." + ) + elif if_exists == "skip": + return False + + schema_exists = self._check_if_schema_exists(schema) + if not schema_exists: + self.run(f"CREATE SCHEMA {schema}") + + self.logger.info(f"Creating table {fqn}...") + ingest_query = f"CREATE TABLE {fqn} AS SELECT * FROM '{path}';" + self.run(ingest_query) + self.logger.info(f"Table {fqn} has been created successfully.") + + def insert_into_from_parquet(): + # check with Marcin if needed + pass + + def drop_table(self, table: str, schema: str = None) -> bool: + """ + Drop a table. + + This is a thin wraper around DuckDB.run() which logs to the operation. + + Args: + table (str): The table to be dropped. + schema (str, optional): The schema where the table is located. + Defaults to None. + + Returns: + bool: Whether the table was dropped. + """ + + schema = schema or DuckDB.DEFAULT_SCHEMA + fqn = schema + "." + table + + self.logger.info(f"Dropping table {fqn}...") + dropped = self.run(f"DROP TABLE IF EXISTS {fqn}") + if dropped: + self.logger.info(f"Table {fqn} has been dropped successfully.") + else: + self.logger.info(f"Table {fqn} could not be dropped.") + return dropped + + def _check_if_table_exists(self, table: str, schema: str = None) -> bool: + schema = schema or DuckDB.DEFAULT_SCHEMA + fqn = schema + "." 
+ table + return fqn in self.tables + + def _check_if_schema_exists(self, schema: str) -> bool: + if schema == self.DEFAULT_SCHEMA: + return True + fqns = self.tables + return any((fqn.split(".")[0] == schema for fqn in fqns)) diff --git a/viadot/sources/sap_rfc.py b/viadot/sources/sap_rfc.py new file mode 100644 index 000000000..f45a9645c --- /dev/null +++ b/viadot/sources/sap_rfc.py @@ -0,0 +1,432 @@ +import re +from collections import OrderedDict +from typing import List, Literal, Union, Tuple, OrderedDict as OrderedDictType + +import pandas as pd + +try: + import pyrfc +except ModuleNotFoundError: + raise ImportError("pyfrc is required to use the SAPRFC source.") +from sql_metadata import Parser +from viadot.config import local_config +from viadot.exceptions import CredentialError +from viadot.sources.base import Source + + +def remove_whitespaces(text): + return " ".join(text.split()) + + +def get_keyword_for_condition(where: str, condition: str) -> str: + where = where[: where.find(condition)] + return where.split()[-1] + + +def get_where_uppercased(where: str) -> str: + """ + Uppercase a WHERE clause's keywords without + altering the original string. + """ + where_and_uppercased = re.sub("\\sand ", " AND ", where) + where_and_and_or_uppercased = re.sub("\\sor ", " OR ", where_and_uppercased) + return where_and_and_or_uppercased + + +def remove_last_condition(where: str) -> str: + """Remove the last condtion from a WHERE clause.""" + where = get_where_uppercased(where) + split_by_and = re.split("\\sAND ", where) + conditions = [re.split("\\sOR ", expr) for expr in split_by_and] + conditions_flattened = [ + condition for sublist in conditions for condition in sublist + ] + + condition_to_remove = conditions_flattened[-1] + + where_trimmed = where[: where.find(condition_to_remove)].split() + where_trimmed_without_last_keyword = " ".join(where_trimmed[:-1]) + + return where_trimmed_without_last_keyword, condition_to_remove + + +def trim_where(where: str) -> Tuple[str, OrderedDictType[str, str]]: + """ + Trim a WHERE clause to 75 characters or less, + as required by SAP. The rest of filters will be applied + in-memory on client side. + """ + + if len(where) <= 75: + return where, None + + wheres_to_add = OrderedDict() + keywords_with_conditions = [] + where_trimmed = where + while len(where_trimmed) > 75: + # trim the where + where_trimmed, removed_condition = remove_last_condition(where_trimmed) + + # store the removed conditions so we can readd them later + keyword = get_keyword_for_condition(where, removed_condition) + keywords_with_conditions.append((keyword, removed_condition)) + + wheres_to_add_sorted = keywords_with_conditions[::-1] + wheres_to_add = OrderedDict(wheres_to_add_sorted) + + return where_trimmed, wheres_to_add + + +class SAPRFC(Source): + """ + A class for querying SAP with SQL using the RFC protocol. + + Note that only a very limited subset of SQL is supported: + - aliases + - where clauses combined using the AND operator + - limit & offset + + Unsupported: + - aggregations + - joins + - subqueries + - etc. + """ + + def __init__(self, sep: str = None, autopick_sep: bool = True, *args, **kwargs): + """Create an instance of the SAPRFC class. + + Args: + sep (str, optional): Which separator to use when querying SAP. Defaults to None. + autopick_sep (bool, optional): Whether to automatically pick a working separator. + Defaults to True. + + Raises: + CredentialError: If provided credentials are incorrect. 
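
A quick illustration of the WHERE-trimming helpers, assuming `pyrfc` and the SAP NW RFC SDK are installed (importing the module requires them). The filter below is invented for the example and intentionally exceeds SAP's 75-character limit.

```python
from viadot.sources.sap_rfc import trim_where

where = (
    "COMPANYCODE = '0001' AND FISCALYEAR = '2021' "
    "AND CUSTOMER = '00000001' AND REGION = 'PL'"
)

trimmed, leftover = trim_where(where)
# `trimmed` fits within the 75-character limit and is sent to SAP:
#     "COMPANYCODE = '0001' AND FISCALYEAR = '2021' AND CUSTOMER = '00000001'"
# `leftover` maps the joining keyword to the condition applied client-side later:
#     OrderedDict([('AND', "REGION = 'PL'")])
```
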
+ """ + + self._con = None + DEFAULT_CREDENTIALS = local_config.get("SAP").get("DEV") + credentials = kwargs.pop("credentials", None) or DEFAULT_CREDENTIALS + if credentials is None: + raise CredentialError("Missing credentials.") + + super().__init__(*args, credentials=credentials, **kwargs) + + self.sep = sep + self.autopick_sep = autopick_sep + self.client_side_filters = None + + @property + def con(self) -> pyrfc.Connection: + if self._con is not None: + return self._con + con = pyrfc.Connection(**self.credentials) + self._con = con + return con + + def check_connection(self) -> None: + self.logger.info("Checking the connection...") + self.con.ping() + self.logger.info("Connection has been validated successfully.") + + def get_function_parameters( + self, + function_name: str, + description: Union[None, Literal["short", "long"]] = "short", + *args, + ) -> Union[List[str], pd.DataFrame]: + """Get the description for a SAP RFC function. + + Args: + function_name (str): The name of the function to detail. + description (Union[None, Literal[, optional): Whether to display + a short or a long description. Defaults to "short". + + Raises: + ValueError: If the argument for description is incorrect. + + Returns: + Union[List[str], pd.DataFrame]: Either a list of the function's + parameter names (if 'description' is set to None), + or a short or long description. + """ + if description is not None: + if description not in ["short", "long"]: + raise ValueError( + "Incorrect value for 'description'. Correct values: (None, 'short', 'long'" + ) + + descr = self.con.get_function_description(function_name, *args) + param_names = [param["name"] for param in descr.parameters] + detailed_params = descr.parameters + filtered_detailed_params = [ + { + "name": param["name"], + "parameter_type": param["parameter_type"], + "default_value": param["default_value"], + "optional": param["optional"], + "parameter_text": param["parameter_text"], + } + for param in descr.parameters + ] + + if description is not None: + if description == "long": + params = detailed_params + else: + params = filtered_detailed_params + params = pd.DataFrame.from_records(params) + else: + params = param_names + + return params + + def _get_where_condition(self, sql: str) -> str: + """Retrieve the WHERE conditions from a SQL query. + + Args: + sql (str): The input SQL query. + + Raises: + ValueError: Raised if the WHERE clause is longer than + 75 characters (SAP's limitation) and the condition for the + extra clause(s) is OR. + + Returns: + str: The where clause trimmed to <= 75 characters. + """ + + where_match = re.search("\\sWHERE ", sql.upper()) + if not where_match: + return None + + limit_match = re.search("\\sLIMIT ", sql.upper()) + limit_pos = limit_match.span()[0] if limit_match else len(sql) + + where = sql[where_match.span()[1] : limit_pos] + where_sanitized = remove_whitespaces(where) + where_trimmed, client_side_filters = trim_where(where_sanitized) + if client_side_filters: + self.logger.warning( + "A WHERE clause longer than 75 character limit detected." + ) + if "OR" in [key.upper() for key in client_side_filters.keys()]: + raise ValueError( + "WHERE conditions after the 75 character limit can only be combined with the AND keyword." + ) + else: + filters_pretty = list(client_side_filters.items()) + self.logger.warning( + f"Trimmed conditions ({filters_pretty}) will be applied client-side." 
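
Assuming SAP credentials are present in the local viadot config (under `SAP` and then `DEV`), the source can be used to inspect an RFC function before querying; `RFC_READ_TABLE` is the function the source calls internally.

```python
from viadot.sources import SAPRFC

sap = SAPRFC()            # credentials are read from the local viadot config
sap.check_connection()    # pings the SAP system

# Parameter names only; pass description="short" or "long" for a DataFrame instead.
param_names = sap.get_function_parameters("RFC_READ_TABLE", description=None)
print(param_names)
```
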
+ ) + self.logger.warning(f"See the documentation for caveats.") + + self.client_side_filters = client_side_filters + return where_trimmed + + @staticmethod + def _get_table_name(sql: str) -> str: + parsed = Parser(sql) + if len(parsed.tables) > 1: + raise ValueError("Querying more than one table is not supported.") + return parsed.tables[0] + + def _build_pandas_filter_query( + self, client_side_filters: OrderedDictType[str, str] + ) -> str: + """Build a WHERE clause that will be applied client-side. + This is required if the WHERE clause passed to query() is + longer than 75 characters. + + Args: + client_side_filters (OrderedDictType[str, str]): The + client-side filters to apply. + + Returns: + str: the WHERE clause reformatted to fit the format + required by DataFrame.query(). + """ + for i, f in enumerate(client_side_filters.items()): + if i == 0: + # skip the first keyword; we assume it's "AND" + query = f[1] + else: + query += " " + f[0] + " " + f[1] + + filter_column_name = f[1].split()[0] + resolved_column_name = self._resolve_col_name(filter_column_name) + query = re.sub("\\s?=\\s?", " == ", query).replace( + filter_column_name, resolved_column_name + ) + return query + + def extract_values(self, sql: str) -> None: + """TODO: This should cover all values, not just columns""" + self.where = self._get_where_condition(sql) + self.select_columns = self._get_columns(sql, aliased=False) + self.select_columns_aliased = self._get_columns(sql, aliased=True) + + def _resolve_col_name(self, column: str) -> str: + """Get aliased column name if it exists, otherwise return column name.""" + return self.aliases_keyed_by_columns.get(column, column) + + def _get_columns(self, sql: str, aliased: bool = False) -> List[str]: + """Retrieve column names from a SQL query. + + Args: + sql (str): The SQL query to parse. + aliased (bool, optional): Whether to returned aliased + names. Defaults to False. + + Returns: + List[str]: A list of column names. + """ + parsed = Parser(sql) + columns = list(parsed.columns_dict["select"]) + if aliased: + aliases_keyed_by_alias = parsed.columns_aliases + aliases_keyed_by_columns = OrderedDict( + {val: key for key, val in aliases_keyed_by_alias.items()} + ) + + self.aliases_keyed_by_columns = aliases_keyed_by_columns + + columns = [ + aliases_keyed_by_columns[col] + if col in aliases_keyed_by_columns + else col + for col in columns + ] + + if self.client_side_filters: + # In case the WHERE clause is > 75 characters long, we execute the rest of the filters + # client-side. To do this, we need to pull all fields in the client-side WHERE conditions. + # Below code adds these columns to the list of SELECTed fields. 
+ cols_to_add = [v.split()[0] for v in self.client_side_filters.values()] + if aliased: + cols_to_add = [aliases_keyed_by_columns[col] for col in cols_to_add] + columns.extend(cols_to_add) + columns = list(dict.fromkeys(columns)) # remove duplicates + + return columns + + @staticmethod + def _get_limit(sql: str) -> int: + """Get limit from the query""" + limit_match = re.search("\\sLIMIT ", sql.upper()) + if not limit_match: + return None + + return int(sql[limit_match.span()[1] :].split()[0]) + + @staticmethod + def _get_offset(sql: str) -> int: + """Get offset from the query""" + offset_match = re.search("\\sOFFSET ", sql.upper()) + if not offset_match: + return None + + return int(sql[offset_match.span()[1] :].split()[0]) + + def query(self, sql: str, sep: str = None) -> None: + """Parse an SQL query into pyRFC commands and save it into + an internal dictionary. + + Args: + sql (str): The SQL query to be ran. + sep (str, optional): The separator to be used + to split columns in the result blob. Defaults to self.sep. + + Raises: + ValueError: If the query is not a SELECT query. + """ + + if not sql.strip().upper().startswith("SELECT"): + raise ValueError("Only SELECT queries are supported.") + + sep = sep if sep is not None else self.sep + + self.sql = sql + + self.extract_values(sql) + + table_name = self._get_table_name(sql) + # this has to be called before checking client_side_filters + where = self.where + columns = self.select_columns + options = [{"TEXT": where}] if where else None + limit = self._get_limit(sql) + offset = self._get_offset(sql) + query_json = dict( + QUERY_TABLE=table_name, + FIELDS=columns, + OPTIONS=options, + ROWCOUNT=limit, + ROWSKIPS=offset, + DELIMITER=sep, + ) + # SAP doesn't understand None, so we filter out non-specified parameters + query_json_filtered = { + key: query_json[key] for key in query_json if query_json[key] is not None + } + self._query = query_json_filtered + + def call(self, func: str, *args, **kwargs): + """Call a SAP RFC function""" + return self.con.call(func, *args, **kwargs) + + def _get_alias(self, column: str) -> str: + return self.aliases_keyed_by_columns.get(column, column) + + def _get_client_side_filter_cols(self): + return [f[1].split()[0] for f in self.client_side_filters.items()] + + def to_df(self): + """ + Load the results of a query into a pandas DataFrame. + + Due to SAP limitations, if the length of the WHERE clause is longer than 75 + characters, we trim whe WHERE clause and perform the rest of the filtering + on the resulting DataFrame. Eg. if the WHERE clause contains 4 conditions + and has 80 characters, we only perform 3 filters in the query, and perform + the last filter on the DataFrame. + + Source: https://success.jitterbit.com/display/DOC/Guide+to+Using+RFC_READ_TABLE+to+Query+SAP+Tables#GuidetoUsingRFC_READ_TABLEtoQuerySAPTables-create-the-operation + - WHERE clause: 75 character limit + - SELECT: 512 character row limit + + Returns: + pd.DataFrame: A DataFrame representing the result of the query provided in `PyRFC.query()`. 
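
Putting the pieces together, a typical call sequence looks like the sketch below. `MARA`, `MATNR`, and `MTART` are standard SAP material-master names used purely as an illustration; credentials again come from the local config.

```python
from viadot.sources import SAPRFC

sap = SAPRFC(sep="|")
sap.query("SELECT MATNR, MTART FROM MARA WHERE MTART = 'FERT' LIMIT 100")
df = sap.to_df()
```
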
+ """ + params = self._query + columns = self.select_columns_aliased + sep = self._query.get("DELIMITER") + + if sep is None or self.autopick_sep: + SEPARATORS = ["|", "/t", "#", ";", "@"] + for sep in SEPARATORS: + self._query["DELIMITER"] = sep + try: + response = self.call("RFC_READ_TABLE", **params) + record_key = "WA" + data_raw = response["DATA"] + records = [row[record_key].split(sep) for row in data_raw] + except ValueError: + continue + df = pd.DataFrame(records, columns=columns) + + if self.client_side_filters: + filter_query = self._build_pandas_filter_query(self.client_side_filters) + df.query(filter_query, inplace=True) + client_side_filter_cols_aliased = [ + self._get_alias(col) for col in self._get_client_side_filter_cols() + ] + cols_to_drop = [ + col + for col in client_side_filter_cols_aliased + if col not in self.select_columns_aliased + ] + df.drop(cols_to_drop, axis=1, inplace=True) + + return df diff --git a/viadot/sources/sharepoint.py b/viadot/sources/sharepoint.py index 8accf19f7..49c925015 100644 --- a/viadot/sources/sharepoint.py +++ b/viadot/sources/sharepoint.py @@ -51,7 +51,11 @@ def download_file( download_from_path: str = None, download_to_path: str = "Sharepoint_file.xlsm", ) -> None: - + """Function to download files from Sharepoint. + Args: + download_from_path (str): Path from which to download file. Defaults to None. + download_to_path (str, optional): Path to destination file. Defaults to "Sharepoint_file.xlsm". + """ download_from_path = download_from_path or self.url if not download_from_path: raise ValueError("Missing required parameter 'download_from_path'.") diff --git a/viadot/sources/sqlite.py b/viadot/sources/sqlite.py index 8cbcf87b0..2e778f408 100644 --- a/viadot/sources/sqlite.py +++ b/viadot/sources/sqlite.py @@ -40,6 +40,11 @@ def conn_str(self): return conn_str def _check_if_table_exists(self, table: str, schema: str = None) -> bool: + """Checks if table exists. + Args: + table (str): Table name. + schema (str, optional): Schema name. Defaults to None. + """ fqn = f"{schema}.{table}" if schema is not None else table exists_query = ( f"SELECT name FROM sqlite_master WHERE type='table' AND name='{fqn}'" diff --git a/viadot/sources/supermetrics.py b/viadot/sources/supermetrics.py index 92b7a88eb..77cef4687 100644 --- a/viadot/sources/supermetrics.py +++ b/viadot/sources/supermetrics.py @@ -47,6 +47,7 @@ def __init__(self, *args, query_params: Dict[str, Any] = None, **kwargs): @classmethod def get_params_from_api_query(cls, url: str) -> Dict[str, Any]: + """Returns parmeters from API query in a dictionary""" url_unquoted = urllib.parse.unquote(url) s = urllib.parse.parse_qs(url_unquoted) endpoint = list(s.keys())[0] @@ -87,6 +88,7 @@ def _get_col_names_google_analytics( cls, response: dict, ) -> List[str]: + """Returns list of Google Analytics columns names""" # Supermetrics allows pivoting GA data, in which case it generates additional columns, # which are not enlisted in response's query metadata but are instead added as the first row of data. 
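
The Sharepoint source's `download_file` can also be called directly; the URL and file names below are placeholders, and credentials are resolved by the source itself.

```python
from viadot.sources import Sharepoint

s = Sharepoint(download_from_path="https://example.sharepoint.com/sites/reports/file.xlsm")
s.download_file(download_to_path="Sharepoint_file.xlsm")
```
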
@@ -109,11 +111,13 @@ def _get_col_names_google_analytics( @classmethod def _get_col_names_other(cls, response: dict) -> List[str]: + """Returns list of columns names (to Google Analytics use _get_col_names_google_analytics ()""" cols_meta = response["meta"]["query"]["fields"] columns = [col_meta["field_name"] for col_meta in cols_meta] return columns def _get_col_names(self) -> List[str]: + """Returns list of columns names""" query_params_cp = deepcopy(self.query_params) query_params_cp["offset_start"] = 0 diff --git a/viadot/sources/uk_carbon_intensity.py b/viadot/sources/uk_carbon_intensity.py index 8a4bc9ecb..91b4f9c83 100644 --- a/viadot/sources/uk_carbon_intensity.py +++ b/viadot/sources/uk_carbon_intensity.py @@ -22,6 +22,7 @@ def __init__(self, *args, api_url: str = None, **kwargs): self.API_ENDPOINT = "https://api.carbonintensity.org.uk" def to_json(self): + """Creates json file""" url = f"{self.API_ENDPOINT}{self.api_url}" headers = {"Accept": "application/json"} response = requests.get(url, params={}, headers=headers) @@ -78,7 +79,7 @@ def to_df(self, if_empty: str = "warn"): "min": min_, } ) - return df + return df def query(self, api_url: str): self.api_url = api_url diff --git a/viadot/task_utils.py b/viadot/task_utils.py index aeb7375d9..059222c38 100644 --- a/viadot/task_utils.py +++ b/viadot/task_utils.py @@ -66,6 +66,12 @@ def dtypes_to_json_task(dtypes_dict, local_json_path: str): @task def chunk_df(df: pd.DataFrame, size: int = 10_000) -> List[pd.DataFrame]: + """ + Creates pandas Dataframes list of chunks with a given size. + Args: + df (pd.DataFrame): Input pandas DataFrame. + size (int, optional): Size of a chunk. Defaults to 10000. + """ n_rows = df.shape[0] chunks = [df[i : i + size] for i in range(0, n_rows, size)] return chunks @@ -73,6 +79,11 @@ def chunk_df(df: pd.DataFrame, size: int = 10_000) -> List[pd.DataFrame]: @task def df_get_data_types_task(df: pd.DataFrame) -> dict: + """ + Returns dictionary containing datatypes of pandas DataFrame columns. + Args: + df (pd.DataFrame): Input pandas DataFrame. + """ typeset = CompleteSet() dtypes = infer_type(df, typeset) dtypes_dict = {k: str(v) for k, v in dtypes.items()} @@ -127,16 +138,35 @@ def df_to_csv( if_exists: Literal["append", "replace", "skip"] = "replace", **kwargs, ) -> None: + + """ + Task to create csv file based on pandas DataFrame. + Args: + df (pd.DataFrame): Input pandas DataFrame. + path (str): Path to output csv file. + sep (str, optional): The separator to use in the CSV. Defaults to "\t". + if_exists (Literal["append", "replace", "skip"], optional): What to do if the table exists. Defaults to "replace". + """ + if if_exists == "append" and os.path.isfile(path): csv_df = pd.read_csv(path, sep=sep) out_df = pd.concat([csv_df, df]) elif if_exists == "replace": out_df = df - elif if_exists == "skip": + elif if_exists == "skip" and os.path.isfile(path): logger.info("Skipped.") return else: out_df = df + + # create directories if they don't exist + try: + if not os.path.isfile(path): + directory = os.path.dirname(path) + os.makedirs(directory, exist_ok=True) + except: + pass + out_df.to_csv(path, index=False, sep=sep) @@ -147,6 +177,13 @@ def df_to_parquet( if_exists: Literal["append", "replace", "skip"] = "replace", **kwargs, ) -> None: + """ + Task to create parquet file based on pandas DataFrame. + Args: + df (pd.DataFrame): Input pandas DataFrame. + path (str): Path to output parquet file. + if_exists (Literal["append", "replace", "skip"], optional): What to do if the table exists. 
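
These task utilities are Prefect tasks, so outside of a flow they are invoked via `.run()`; the output path is an example, and the inferred dtype names are only indicative.

```python
import pandas as pd
from viadot.task_utils import chunk_df, df_get_data_types_task, df_to_csv

df = pd.DataFrame({"a": range(25_000)})

chunks = chunk_df.run(df, size=10_000)     # -> list of 3 DataFrames
dtypes = df_get_data_types_task.run(df)    # -> e.g. {"a": "Integer"}
df_to_csv.run(df, path="output/data.csv")  # the parent directory is created if missing
```
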
Defaults to "replace". + """ if if_exists == "append" and os.path.isfile(path): parquet_df = pd.read_parquet(path) out_df = pd.concat([parquet_df, df]) @@ -162,18 +199,34 @@ def df_to_parquet( @task def dtypes_to_json(dtypes_dict: dict, local_json_path: str) -> None: + """ + Creates json file from a dictionary. + Args: + dtypes_dict (dict): Dictionary containing data types. + local_json_path (str): Path to local json file. + """ with open(local_json_path, "w") as fp: json.dump(dtypes_dict, fp) @task def union_dfs_task(dfs: List[pd.DataFrame]): + """ + Create one DataFrame from a list of pandas DataFrames. + Args: + dfs (List[pd.DataFrame]): List of pandas Dataframes to concat. In case of different size of DataFrames NaN values can appear. + """ return pd.concat(dfs, ignore_index=True) @task def write_to_json(dict_, path): - + """ + Creates json file from a dictionary. Log record informs about the writing file proccess. + Args: + dict_ (dict): Dictionary. + path (str): Path to local json file. + """ logger = prefect.context.get("logger") if os.path.isfile(path): diff --git a/viadot/tasks/__init__.py b/viadot/tasks/__init__.py index 000e3d42e..7860a77df 100644 --- a/viadot/tasks/__init__.py +++ b/viadot/tasks/__init__.py @@ -16,6 +16,7 @@ AzureSQLCreateTable, AzureSQLDBQuery, CreateTableFromBlob, + CheckColumnOrder, ) from .bcp import BCPTask from .github import DownloadGitHubFile @@ -24,3 +25,10 @@ from .supermetrics import SupermetricsToCSV, SupermetricsToDF from .sharepoint import SharepointToDF from .cloud_for_customers import C4CReportToDF, C4CToDF + +try: + from .sap_rfc import SAPRFCToDF +except ImportError: + pass + +from .duckdb import DuckDBCreateTableFromParquet, DuckDBQuery diff --git a/viadot/tasks/azure_blob_storage.py b/viadot/tasks/azure_blob_storage.py index a80e3c1d9..b15399bf2 100644 --- a/viadot/tasks/azure_blob_storage.py +++ b/viadot/tasks/azure_blob_storage.py @@ -6,6 +6,10 @@ class BlobFromCSV(Task): + """ + Task for generating Azure Blob Storage from CSV file + """ + def __init__(self, *args, **kwargs): super().__init__(name="csv_to_blob_storage", *args, **kwargs) diff --git a/viadot/tasks/azure_key_vault.py b/viadot/tasks/azure_key_vault.py index 7f88d02e6..428b3d457 100644 --- a/viadot/tasks/azure_key_vault.py +++ b/viadot/tasks/azure_key_vault.py @@ -12,6 +12,16 @@ def get_key_vault( credentials: str, secret_client_kwargs: dict, vault_name: str = None ) -> SecretClient: + """ + Get Azure Key Vault. + + Args: + credentials (str): Azure Key Vault credentials. + secret_client_kwargs (dict): Keyword arguments to forward to the SecretClient. + vault_name (str, optional): The name of the vault. Defaults to None + + Returns: Azure Key Vault + """ if not vault_name: vault_name = PrefectSecret("AZURE_DEFAULT_KEYVAULT").run() if credentials: diff --git a/viadot/tasks/azure_sql.py b/viadot/tasks/azure_sql.py index 8db166df0..07c3beb14 100644 --- a/viadot/tasks/azure_sql.py +++ b/viadot/tasks/azure_sql.py @@ -1,6 +1,7 @@ import json from datetime import timedelta -from typing import Any, Dict, Literal +from typing import Any, Dict, List, Literal +import pandas as pd from prefect import Task from prefect.tasks.secrets import PrefectSecret @@ -9,8 +10,21 @@ from ..sources import AzureSQL from .azure_key_vault import AzureKeyVaultSecret +from ..exceptions import ValidationError + def get_credentials(credentials_secret: str, vault_name: str = None): + """ + Get Azure credentials. 
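
A short sketch of the remaining DataFrame utilities; the dtype dictionary and file path are made up.

```python
import pandas as pd
from viadot.task_utils import union_dfs_task, dtypes_to_json

df1 = pd.DataFrame({"a": [1, 2]})
df2 = pd.DataFrame({"a": [3], "b": ["x"]})   # the extra column yields NaNs after the union

union = union_dfs_task.run([df1, df2])

dtypes_to_json.run({"a": "Int64", "b": "String"}, local_json_path="dtypes.json")
```
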
+ + Args: + credentials_secret (str): The name of the Azure Key Vault secret containing a dictionary + with SQL db credentials (server, db_name, user and password). + vault_name (str, optional): The name of the vault from which to obtain the secret. Defaults to None. + + Returns: Credentials + + """ if not credentials_secret: # attempt to read a default for the service principal secret name try: @@ -264,3 +278,74 @@ def run( self.logger.info(f"Successfully ran the query.") return result + + +class CheckColumnOrder(Task): + """ + Task for checking the order of columns in the loaded DF and in the SQL table into which the data from DF will be loaded. + If order is different then DF columns are reordered according to the columns of the SQL table. + """ + + def __init__( + self, + table: str = None, + if_exists: Literal["fail", "replace", "append", "delete"] = "replace", + df: pd.DataFrame = None, + credentials_secret: str = None, + vault_name: str = None, + *args, + **kwargs, + ): + self.credentials_secret = credentials_secret + self.vault_name = vault_name + + super().__init__(name="run_check_column_order", *args, **kwargs) + + def df_change_order( + self, df: pd.DataFrame = None, sql_column_list: List[str] = None + ): + df_column_list = list(df.columns) + if set(df_column_list) == set(sql_column_list): + df_changed = df.loc[:, sql_column_list] + else: + raise ValidationError( + "Detected discrepancies in number of columns or different column names between the CSV file and the SQL table!" + ) + + return df_changed + + def run( + self, + table: str = None, + if_exists: Literal["fail", "replace", "append", "delete"] = "replace", + df: pd.DataFrame = None, + credentials_secret: str = None, + vault_name: str = None, + ): + """ + Run a checking column order + + Args: + table (str, optional): SQL table name without schema. Defaults to None. + if_exists (Literal, optional): What to do if the table exists. Defaults to "replace". + df (pd.DataFrame, optional): Data Frame. Defaults to None. + credentials_secret (str, optional): The name of the Azure Key Vault secret containing a dictionary + with SQL db credentials (server, db_name, user, and password). Defaults to None. + vault_name (str, optional): The name of the vault from which to obtain the secret. Defaults to None. + """ + credentials = get_credentials(credentials_secret, vault_name=vault_name) + azure_sql = AzureSQL(credentials=credentials) + + if if_exists not in ["replace", "fail"]: + query = f"SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = '{table}'" + result = azure_sql.run(query=query) + sql_column_list = [table for row in result for table in row] + df_column_list = list(df.columns) + + if sql_column_list != df_column_list: + self.logger.warning( + "Detected column order difference between the CSV file and the table. Reordering..." 
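
A usage sketch for `CheckColumnOrder`; the table and Key Vault secret names are hypothetical. For `append` and `delete` loads the task compares the DataFrame's columns against the target SQL table and reorders them when only the order differs; a different column set raises a `ValidationError`.

```python
import pandas as pd
from viadot.tasks import CheckColumnOrder

df = pd.DataFrame({"second_col": [1], "first_col": [2]})

CheckColumnOrder().run(
    table="my_table",
    if_exists="append",
    df=df,
    credentials_secret="AZURE-SQL-SECRET",  # hypothetical secret name
)
```
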
+ ) + df = self.df_change_order(df=df, sql_column_list=sql_column_list) + else: + self.logger.info("The table will be replaced.") diff --git a/viadot/tasks/cloud_for_customers.py b/viadot/tasks/cloud_for_customers.py index b46daa3ab..aefec958e 100644 --- a/viadot/tasks/cloud_for_customers.py +++ b/viadot/tasks/cloud_for_customers.py @@ -1,8 +1,12 @@ from prefect import task, Task +import json import pandas as pd from ..sources import CloudForCustomers from typing import Any, Dict, List from prefect.utilities.tasks import defaults_from_attrs +from prefect.tasks.secrets import PrefectSecret +from .azure_key_vault import AzureKeyVaultSecret +from viadot.config import local_config class C4CReportToDF(Task): @@ -43,6 +47,8 @@ def run( env: str = "QA", skip: int = 0, top: int = 1000, + credentials_secret: str = None, + vault_name: str = None, ): """ Task for downloading data from the Cloud for Customers to a pandas DataFrame using report URL @@ -54,15 +60,35 @@ def run( env (str, optional): The development environments. Defaults to 'QA'. skip (int, optional): Initial index value of reading row. Defaults to 0. top (int, optional): The value of top reading row. Defaults to 1000. + credentials_secret (str, optional): The name of the Azure Key Vault secret containing a dictionary + with C4C credentials. Defaults to None. + vault_name (str, optional): The name of the vault from which to obtain the secret. Defaults to None. Returns: pd.DataFrame: The query result as a pandas DataFrame. """ + + if not credentials_secret: + try: + credentials_secret = PrefectSecret("C4C_KV").run() + except ValueError: + pass + + if credentials_secret: + credentials_str = AzureKeyVaultSecret( + credentials_secret, vault_name=vault_name + ).run() + credentials = json.loads(credentials_str)[env] + else: + credentials = local_config.get("CLOUD_FOR_CUSTOMERS")[env] + final_df = pd.DataFrame() next_batch = True while next_batch: new_url = f"{report_url}&$top={top}&$skip={skip}" - chunk_from_url = CloudForCustomers(report_url=new_url, env=env) + chunk_from_url = CloudForCustomers( + report_url=new_url, env=env, credentials=credentials + ) df = chunk_from_url.to_df() final_df = final_df.append(df) if not final_df.empty: @@ -110,6 +136,8 @@ def run( fields: List[str] = None, params: List[str] = None, if_empty: str = "warn", + credentials_secret: str = None, + vault_name: str = None, ): """ Task for downloading data from the Cloud for Customers to a pandas DataFrame using normal URL (with query parameters). @@ -128,12 +156,33 @@ def run( fields (List[str], optional): The C4C Table fields. Defaults to None. params (Dict[str, Any]): The query parameters like filter by creation date time. Defaults to json format. if_empty (str, optional): What to do if query returns no data. Defaults to "warn". + credentials_secret (str, optional): The name of the Azure Key Vault secret containing a dictionary + with C4C credentials. Defaults to None. + vault_name (str, optional): The name of the vault from which to obtain the secret. Defaults to None. Returns: pd.DataFrame: The query result as a pandas DataFrame. 
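
With the Key Vault support added here, the C4C tasks can resolve credentials from a secret or fall back to the local config. The report URL and secret name below are placeholders.

```python
from viadot.tasks import C4CReportToDF

report_to_df = C4CReportToDF()
df = report_to_df.run(
    report_url="https://my-tenant.crm.ondemand.com/analytics/MyReport",  # placeholder
    env="QA",
    credentials_secret="C4C-QA",  # hypothetical; omit to use the local config
    skip=0,
    top=1000,
)
```
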
""" + if not credentials_secret: + try: + credentials_secret = PrefectSecret("C4C_KV").run() + except ValueError: + pass + + if credentials_secret: + credentials_str = AzureKeyVaultSecret( + credentials_secret, vault_name=vault_name + ).run() + credentials = json.loads(credentials_str)[env] + else: + credentials = local_config.get("CLOUD_FOR_CUSTOMERS")[env] cloud_for_customers = CloudForCustomers( - url=url, params=params, endpoint=endpoint, env=env, fields=fields + url=url, + params=params, + endpoint=endpoint, + env=env, + fields=fields, + credentials=credentials, ) df = cloud_for_customers.to_df(if_empty=if_empty, fields=fields) diff --git a/viadot/tasks/duckdb.py b/viadot/tasks/duckdb.py new file mode 100644 index 000000000..cb4d5e3c8 --- /dev/null +++ b/viadot/tasks/duckdb.py @@ -0,0 +1,130 @@ +from typing import Any, List, Literal, Tuple, Union, NoReturn + +from prefect import Task +from prefect.utilities.tasks import defaults_from_attrs + +from ..sources import DuckDB + +Record = Tuple[Any] + + +class DuckDBQuery(Task): + """ + Task for running a query on DuckDB. + + Args: + credentials (dict, optional): The config to use for connecting with the db. + """ + + def __init__( + self, + credentials: dict = None, + *args, + **kwargs, + ): + self.credentials = credentials + super().__init__(name="run_duckdb_query", *args, **kwargs) + + @defaults_from_attrs("credentials") + def run( + self, + query: str, + fetch_type: Literal["record", "dataframe"] = "record", + credentials: dict = None, + ) -> Union[List[Record], bool]: + """Run a query on DuckDB. + + Args: + query (str, required): The query to execute. + fetch_type (Literal[, optional): How to return the data: either + in the default record format or as a pandas DataFrame. Defaults to "record". + credentials (dict, optional): The config to use for connecting with the db. + + Returns: + Union[List[Record], bool]: Either the result set of a query or, + in case of DDL/DML queries, a boolean describing whether + the query was excuted successfuly. + """ + + duckdb = DuckDB(credentials=credentials) + + # run the query and fetch the results if it's a select + result = duckdb.run(query, fetch_type=fetch_type) + + self.logger.info(f"Successfully ran the query.") + return result + + +class DuckDBCreateTableFromParquet(Task): + """ + Task for creating a DuckDB table with a CTAS from Parquet file(s). + + Args: + table (str, optional): Destination table. + also allowed here (eg. `my_folder/*.parquet`). + schema (str, optional): Destination schema. + if_exists (Literal, optional): What to do if the table already exists. + credentials(dict, optional): The config to use for connecting with the db. + + Raises: + ValueError: If the table exists and `if_exists` is set to `fail`. + + Returns: + NoReturn: Does not return anything. + """ + + def __init__( + self, + schema: str = None, + if_exists: Literal["fail", "replace", "skip", "delete"] = "fail", + credentials: dict = None, + *args, + **kwargs, + ): + self.schema = schema + self.if_exists = if_exists + self.credentials = credentials + + super().__init__( + name="duckdb_create_table", + *args, + **kwargs, + ) + + @defaults_from_attrs("schema", "if_exists") + def run( + self, + table: str, + path: str, + schema: str = None, + if_exists: Literal["fail", "replace", "skip", "delete"] = None, + ) -> NoReturn: + """ + Create a DuckDB table with a CTAS from Parquet file(s). + + Args: + table (str, optional): Destination table. + path (str): The path to the source Parquet file(s). 
Glob expressions are + also allowed here (eg. `my_folder/*.parquet`). + schema (str, optional): Destination schema. + if_exists (Literal, optional): What to do if the table already exists. + + Raises: + ValueError: If the table exists and `if_exists` is set to `fail`. + + Returns: + NoReturn: Does not return anything. + """ + + duckdb = DuckDB(credentials=self.credentials) + + fqn = f"{schema}.{table}" if schema is not None else table + created = duckdb.create_table_from_parquet( + path=path, schema=schema, table=table, if_exists=if_exists + ) + if created: + self.logger.info(f"Successfully created table {fqn}.") + else: + self.logger.info( + f"Table {fqn} has not been created as if_exists is set to {if_exists}." + ) diff --git a/viadot/tasks/great_expectations.py b/viadot/tasks/great_expectations.py index 1440e4c00..752cc6d47 100644 --- a/viadot/tasks/great_expectations.py +++ b/viadot/tasks/great_expectations.py @@ -161,6 +161,7 @@ def run( def _get_stats_from_results( self, result: ValidationOperatorResult ) -> Tuple[int, int]: + """Returns Tuple containing number of successful and evaluated expectations""" result_identifier = result.list_validation_result_identifiers()[0] stats = result._list_validation_statistics()[result_identifier] n_successful = stats["successful_expectations"] diff --git a/viadot/tasks/sap_rfc.py b/viadot/tasks/sap_rfc.py new file mode 100644 index 000000000..d44dac2bb --- /dev/null +++ b/viadot/tasks/sap_rfc.py @@ -0,0 +1,98 @@ +from datetime import timedelta + +import pandas as pd +from prefect import Task +from prefect.utilities.tasks import defaults_from_attrs + +try: + from ..sources import SAPRFC +except ImportError: + raise + + +class SAPRFCToDF(Task): + def __init__( + self, + query: str = None, + sep: str = "\t", + autopick_sep: bool = True, + credentials: dict = None, + max_retries: int = 3, + retry_delay: timedelta = timedelta(seconds=10), + *args, + **kwargs, + ): + """ + A task for querying SAP with SQL using the RFC protocol. + + Note that only a very limited subset of SQL is supported: + - aliases + - where clauses combined using the AND operator + - limit & offset + + Unsupported: + - aggregations + - joins + - subqueries + - etc. + + Args: + query (str, optional): The query to be executed with pyRFC. + sep (str, optional): The separator to use when reading query results. Defaults to "\t". + autopick_sep (str, optional): Whether SAPRFC should try different separators in case + the query fails with the default one. + credentials (dict, optional): The credentials to use to authenticate with SAP. + By default, they're taken from the local viadot config. + """ + self.query = query + self.sep = sep + self.autopick_sep = autopick_sep + self.credentials = credentials + + super().__init__( + name="sap_rfc_to_df", + max_retries=max_retries, + retry_delay=retry_delay, + *args, + **kwargs, + ) + + @defaults_from_attrs( + "query", + "sep", + "autopick_sep", + "credentials", + "max_retries", + "retry_delay", + ) + def run( + self, + query: str = None, + sep: str = None, + autopick_sep: bool = None, + credentials: dict = None, + max_retries: int = None, + retry_delay: timedelta = None, + ) -> pd.DataFrame: + """Task run method. + + Args: + query (str, optional): The query to be executed with pyRFC. + sep (str, optional): The separator to use when reading a CSV file. Defaults to "\t". + autopick_sep (str, optional): Whether SAPRFC should try different separators in case + the query fails with the default one. 
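
And the companion task, which wraps the source's `create_table_from_parquet`; paths and names are again made up.

```python
from viadot.tasks import DuckDBCreateTableFromParquet

create_table = DuckDBCreateTableFromParquet(credentials={"database": "test.duckdb"})
create_table.run(
    table="sales",
    path="data/sales/*.parquet",
    schema="staging",
    if_exists="replace",
)
```
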
+ """ + + if query is None: + raise ValueError("Please provide the query.") + + sap = SAPRFC(sep=sep, autopick_sep=autopick_sep, credentials=credentials) + sap.query(query) + + self.logger.info(f"Downloading data from SAP to a DataFrame...") + self.logger.debug(f"Running query: \n{query}.") + + df = sap.to_df() + + self.logger.info(f"Data has been downloaded successfully.") + return df diff --git a/viadot/tasks/sharepoint.py b/viadot/tasks/sharepoint.py index 9ae3b2d82..712d3e7be 100644 --- a/viadot/tasks/sharepoint.py +++ b/viadot/tasks/sharepoint.py @@ -1,13 +1,16 @@ from typing import List import os import copy +import json import pandas as pd from prefect import Task from prefect.utilities.tasks import defaults_from_attrs from prefect.utilities import logging +from prefect.tasks.secrets import PrefectSecret from ..exceptions import ValidationError from ..sources import Sharepoint +from .azure_key_vault import AzureKeyVaultSecret logger = logging.get_logger() @@ -79,6 +82,16 @@ def check_column_names( return df_header_list def df_replace_special_chars(self, df: pd.DataFrame): + """ + Replace "\n" and "\t" with "". + + Args: + df (pd.DataFrame): Pandas data frame to replace characters. + + Returns: + df (pd.DataFrame): Pandas data frame + + """ return df.replace(r"\n|\t", "", regex=True) def split_sheet( @@ -137,6 +150,8 @@ def run( nrows: int = 50000, validate_excel_file: bool = False, sheet_number: int = None, + credentials_secret: str = None, + vault_name: str = None, **kwargs, ) -> None: """ @@ -148,16 +163,32 @@ def run( nrows (int, optional): Number of rows to read at a time. Defaults to 50000. sheet_number (int): Sheet number to be extracted from file. Counting from 0, if None all sheets are axtracted. Defaults to None. validate_excel_file (bool, optional): Check if columns in separate sheets are the same. Defaults to False. + credentials_secret (str, optional): The name of the Azure Key Vault secret containing a dictionary with + ACCOUNT_NAME and Service Principal credentials (TENANT_ID, CLIENT_ID, CLIENT_SECRET). Defaults to None. + vault_name (str, optional): The name of the vault from which to obtain the secret. Defaults to None. 
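
A sketch of the task wrapper; the package-level import only succeeds when `pyrfc` and the SAP NW RFC SDK are installed, and the query uses standard customer-master names purely as an example.

```python
from viadot.tasks import SAPRFCToDF

sap_to_df = SAPRFCToDF(sep="\t", autopick_sep=True)
df = sap_to_df.run(query="SELECT KUNNR, NAME1 FROM KNA1 LIMIT 10")
```
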
Returns: pd.DataFrame: Pandas data frame """ + if not credentials_secret: + # attempt to read a default for the service principal secret name + try: + credentials_secret = PrefectSecret("SHAREPOINT_KV").run() + except ValueError: + pass + + if credentials_secret: + credentials_str = AzureKeyVaultSecret( + credentials_secret, vault_name=vault_name + ).run() + credentials = json.loads(credentials_str) + self.path_to_file = path_to_file self.url_to_file = url_to_file path_to_file = os.path.basename(self.path_to_file) self.sheet_number = sheet_number - s = Sharepoint(download_from_path=self.url_to_file) + s = Sharepoint(download_from_path=self.url_to_file, credentials=credentials) s.download_file(download_to_path=path_to_file) self.nrows = nrows diff --git a/viadot/tasks/sqlite.py b/viadot/tasks/sqlite.py index 3862d8d21..2faf54cd3 100644 --- a/viadot/tasks/sqlite.py +++ b/viadot/tasks/sqlite.py @@ -2,6 +2,7 @@ from typing import Any, Dict import pandas as pd +from pendulum import instance import prefect from prefect import Task from prefect.utilities.tasks import defaults_from_attrs @@ -54,7 +55,13 @@ def run( sqlite.create_table( table=table_name, schema=schema, dtypes=dtypes, if_exists=if_exists ) - sqlite.insert_into(table=table_name, df=df) + logger = prefect.context.get("logger") + if isinstance(df, pd.DataFrame) == False: + logger.warning("Object is not a pandas DataFrame") + elif df.empty: + logger.warning("DataFrame is empty") + else: + sqlite.insert_into(table=table_name, df=df) return True diff --git a/viadot/tasks/supermetrics.py b/viadot/tasks/supermetrics.py index 2ab63e1de..93f30e4a3 100644 --- a/viadot/tasks/supermetrics.py +++ b/viadot/tasks/supermetrics.py @@ -9,6 +9,22 @@ class SupermetricsToCSV(Task): + """ + Task to downloading data from Supermetrics API to CSV file. + + Args: + path (str, optional): The destination path. Defaults to "supermetrics_extract.csv". + max_retries (int, optional): The maximum number of retries. Defaults to 5. + retry_delay (timedelta, optional): The delay between task retries. Defaults to 10 seconds. + timeout (int, optional): Task timeout. Defaults to 30 minuntes. + max_rows (int, optional): Maximum number of rows the query results should contain. Defaults to 1 000 000. + max_cols (int, optional): Maximum number of columns the query results should contain. Defaults to None. + if_exists (str, optional): What to do if file already exists. Defaults to "replace". + if_empty (str, optional): What to do if query returns no data. Defaults to "warn". + sep (str, optional): The separator in a target csv file. Defaults to "/t". + + """ + def __init__( self, *args, @@ -75,6 +91,33 @@ def run( sep: str = None, ): + """ + Task run method. + + Args: + path (str, optional): The destination path. Defaulrs to None + ds_id (str, optional): A Supermetrics query parameter. + ds_accounts (Union[str, List[str]], optional): A Supermetrics query parameter. Defaults to None. + ds_segments (List[str], optional): A Supermetrics query parameter. Defaults to None. + ds_user (str, optional): A Supermetrics query parameter. Defaults to None. + fields (List[str], optional): A Supermetrics query parameter. Defaults to None. + date_range_type (str, optional): A Supermetrics query parameter. Defaults to None. + start_date (str, optional): A Supermetrics query parameter. Defaults to None. + end_date (str, optional) A Supermetrics query parameter. Defaults to None. + settings (Dict[str, Any], optional): A Supermetrics query parameter. Defaults to None. 
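
A usage sketch for `SharepointToDF` with the new Key Vault support; the file URL, local path, and secret name are hypothetical, and credentials fall back to the `SHAREPOINT_KV` Prefect secret when no secret name is given.

```python
from viadot.tasks import SharepointToDF

sharepoint_to_df = SharepointToDF()
df = sharepoint_to_df.run(
    path_to_file="reports/file.xlsm",
    url_to_file="https://example.sharepoint.com/sites/reports/file.xlsm",
    nrows=50_000,
    sheet_number=0,
    credentials_secret="SHAREPOINT-CREDS",  # hypothetical secret name
)
```
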
+        filter (str, optional): A Supermetrics query parameter. Defaults to None.
+        max_rows (int, optional): A Supermetrics query parameter. Defaults to None.
+        max_columns (int, optional): A Supermetrics query parameter. Defaults to None.
+        order_columns (str, optional): A Supermetrics query parameter. Defaults to None.
+        if_exists (str, optional): What to do if the file already exists. Defaults to "replace".
+        if_empty (str, optional): What to do if the query returns no data. Defaults to "warn".
+        max_retries (int, optional): The maximum number of retries. Defaults to 5.
+        retry_delay (timedelta, optional): The delay between task retries. Defaults to 10 seconds.
+        timeout (int, optional): Task timeout. Defaults to 30 minutes.
+        sep (str, optional): The separator in the target CSV file. Defaults to None.
+
+        """
+
         if max_retries:
             self.max_retries = max_retries
diff --git a/viadot/utils.py b/viadot/utils.py
index b444a1189..671ea61d4 100644
--- a/viadot/utils.py
+++ b/viadot/utils.py
@@ -24,7 +24,7 @@ def handle_api_response(
         url (str): the URL which trying to connect.
         auth (tuple, optional): authorization information. Defaults to None.
         params (Dict[str, Any], optional): the request params also includes parameters such as the content type. Defaults to None.
-        headers: Dict[str, Any], optional): the request headers required by Supermetrics API.
+        headers (Dict[str, Any], optional): the request headers. Defaults to None.
         timeout (tuple, optional): the request times out. Defaults to (3.05, 60 * 30).

     Raises:
@@ -34,7 +34,7 @@
         APIError: defined by user.

     Returns:
-        response
+        requests.models.Response
     """
     try:
         session = requests.Session()