diff --git a/pyproject.toml b/pyproject.toml index 0e2df72c4..c08ebf452 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "paramiko==2.11.0", "defusedxml>=0.7.1", "aiohttp>=3.10.5", + "simple-salesforce==1.12.6", ] requires-python = ">=3.10" readme = "README.md" diff --git a/requirements-dev.lock b/requirements-dev.lock index d4e0a3138..e379c03fa 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -7,6 +7,7 @@ # all-features: false # with-sources: false # generate-hashes: false +# universal: false -e file:. aiohappyeyeballs==2.4.0 @@ -40,6 +41,7 @@ attrs==24.2.0 # via jsonschema # via referencing # via visions + # via zeep babel==2.16.0 # via mkdocs-material bcrypt==4.2.0 @@ -96,12 +98,14 @@ comm==0.2.2 coolname==2.2.0 # via prefect coverage==7.6.1 + # via coverage croniter==2.0.7 # via prefect cryptography==43.0.0 # via moto # via paramiko # via prefect + # via pyjwt cssselect2==0.7.0 # via cairosvg dateparser==1.2.0 @@ -173,6 +177,7 @@ httpcore==1.0.5 # via httpx # via prefect httpx==0.27.0 + # via httpx # via neoteroi-mkdocs # via prefect humanize==4.10.0 @@ -199,6 +204,8 @@ ipykernel==6.29.5 # via mkdocs-jupyter ipython==8.26.0 # via ipykernel +isodate==0.6.1 + # via zeep itsdangerous==2.2.0 # via prefect jedi==0.19.1 @@ -248,6 +255,8 @@ kubernetes==29.0.0 loguru==0.7.2 lumacli==0.1.2 # via viadot2 +lxml==5.3.0 + # via zeep mako==1.3.5 # via alembic markdown==3.7 @@ -306,14 +315,18 @@ mkdocs-include-markdown-plugin==6.2.2 mkdocs-jupyter==0.24.8 mkdocs-material==9.5.32 # via mkdocs-jupyter + # via mkdocs-material mkdocs-material-extensions==1.3.1 # via mkdocs-material mkdocs-mermaid2-plugin==1.1.1 mkdocs-table-reader-plugin==3.0.1 mkdocstrings==0.25.2 + # via mkdocstrings # via mkdocstrings-python mkdocstrings-python==1.10.5 # via mkdocstrings +more-itertools==10.5.0 + # via simple-salesforce moto==5.0.13 multidict==6.0.5 # via aiohttp @@ -385,6 +398,7 @@ platformdirs==4.2.2 # via jupyter-core # via mkdocs-get-deps # via mkdocstrings + # via zeep pluggy==1.5.0 # via pytest prefect==2.20.2 @@ -428,6 +442,8 @@ pygments==2.18.0 # via mkdocs-material # via nbconvert # via rich +pyjwt==2.9.0 + # via simple-salesforce pymdown-extensions==10.9 # via mkdocs-material # via mkdocs-mermaid2-plugin @@ -471,6 +487,7 @@ pytz==2024.1 # via pandas # via prefect # via trino + # via zeep pytzdata==2020.1 # via pendulum pywavelets==1.7.0 @@ -512,15 +529,23 @@ requests==2.32.3 # via mkdocs-mermaid2-plugin # via moto # via o365 + # via requests-file # via requests-oauthlib + # via requests-toolbelt # via responses # via sharepy + # via simple-salesforce # via trino # via viadot2 + # via zeep +requests-file==2.1.0 + # via zeep requests-oauthlib==2.0.0 # via apprise # via kubernetes # via o365 +requests-toolbelt==1.0.0 + # via zeep responses==0.25.3 # via moto rfc3339-validator==0.1.4 @@ -546,8 +571,6 @@ scipy==1.14.0 # via imagehash sendgrid==6.11.0 # via viadot2 -setuptools==73.0.0 - # via mkdocs-mermaid2-plugin sgqlc==16.3 # via prefect-github shapely==2.0.6 @@ -556,9 +579,12 @@ sharepy==2.0.0 # via viadot2 shellingham==1.5.4 # via typer +simple-salesforce==1.12.6 + # via viadot2 six==1.16.0 # via asttokens # via bleach + # via isodate # via jsbeautifier # via kubernetes # via paramiko @@ -618,6 +644,7 @@ trino==0.328.0 typer==0.12.4 # via lumacli # via prefect + # via typer typing-extensions==4.12.2 # via aiosqlite # via alembic @@ -626,6 +653,7 @@ typing-extensions==4.12.2 # via prefect # via pydantic # via pydantic-core + # via simple-salesforce # via 
sqlalchemy # via typer # via uvicorn @@ -670,5 +698,9 @@ xmltodict==0.13.0 # via moto yarl==1.9.4 # via aiohttp +zeep==4.2.1 + # via simple-salesforce zipp==3.20.0 # via importlib-metadata +setuptools==73.0.0 + # via mkdocs-mermaid2-plugin diff --git a/requirements.lock b/requirements.lock index 8fbaf36af..6c8e8c76e 100644 --- a/requirements.lock +++ b/requirements.lock @@ -7,6 +7,7 @@ # all-features: false # with-sources: false # generate-hashes: false +# universal: false -e file:. aiohappyeyeballs==2.4.0 @@ -38,6 +39,7 @@ attrs==24.2.0 # via jsonschema # via referencing # via visions + # via zeep bcrypt==4.2.0 # via paramiko beautifulsoup4==4.12.3 @@ -73,6 +75,7 @@ croniter==2.0.7 cryptography==43.0.0 # via paramiko # via prefect + # via pyjwt dateparser==1.2.0 # via prefect defusedxml==0.7.1 @@ -117,6 +120,7 @@ httpcore==1.0.5 # via httpx # via prefect httpx==0.27.0 + # via httpx # via prefect humanize==4.10.0 # via jinja2-humanize-extension @@ -135,6 +139,8 @@ importlib-resources==6.1.3 # via prefect iniconfig==2.0.0 # via pytest +isodate==0.6.1 + # via zeep itsdangerous==2.2.0 # via prefect jinja2==3.1.4 @@ -154,6 +160,8 @@ kubernetes==29.0.0 # via prefect lumacli==0.1.2 # via viadot2 +lxml==5.3.0 + # via zeep mako==1.3.5 # via alembic markdown==3.7 @@ -165,6 +173,8 @@ markupsafe==2.1.5 # via mako mdurl==0.1.2 # via markdown-it-py +more-itertools==10.5.0 + # via simple-salesforce multidict==6.0.5 # via aiohttp # via yarl @@ -204,6 +214,8 @@ pendulum==2.1.2 # via prefect pillow==10.4.0 # via imagehash +platformdirs==4.3.6 + # via zeep pluggy==1.5.0 # via pytest prefect==2.20.2 @@ -235,6 +247,8 @@ pygit2==1.14.1 # via viadot2 pygments==2.18.0 # via rich +pyjwt==2.9.0 + # via simple-salesforce pynacl==1.5.0 # via paramiko pyodbc==5.1.0 @@ -264,6 +278,7 @@ pytz==2024.1 # via pandas # via prefect # via trino + # via zeep pytzdata==2020.1 # via pendulum pywavelets==1.7.0 @@ -286,14 +301,22 @@ requests==2.32.3 # via kubernetes # via lumacli # via o365 + # via requests-file # via requests-oauthlib + # via requests-toolbelt # via sharepy + # via simple-salesforce # via trino # via viadot2 + # via zeep +requests-file==2.1.0 + # via zeep requests-oauthlib==2.0.0 # via apprise # via kubernetes # via o365 +requests-toolbelt==1.0.0 + # via zeep rfc3339-validator==0.1.4 # via prefect rich==13.7.1 @@ -321,7 +344,10 @@ sharepy==2.0.0 # via viadot2 shellingham==1.5.4 # via typer +simple-salesforce==1.12.6 + # via viadot2 six==1.16.0 + # via isodate # via kubernetes # via paramiko # via python-dateutil @@ -355,6 +381,7 @@ trino==0.328.0 typer==0.12.4 # via lumacli # via prefect + # via typer typing-extensions==4.12.2 # via aiosqlite # via alembic @@ -362,6 +389,7 @@ typing-extensions==4.12.2 # via prefect # via pydantic # via pydantic-core + # via simple-salesforce # via sqlalchemy # via typer # via uvicorn @@ -388,3 +416,5 @@ websockets==12.0 # via prefect yarl==1.9.8 # via aiohttp +zeep==4.2.1 + # via simple-salesforce diff --git a/src/viadot/orchestration/prefect/flows/__init__.py b/src/viadot/orchestration/prefect/flows/__init__.py index 946239b37..7687baaa5 100644 --- a/src/viadot/orchestration/prefect/flows/__init__.py +++ b/src/viadot/orchestration/prefect/flows/__init__.py @@ -15,6 +15,7 @@ from .mediatool_to_adls import mediatool_to_adls from .mindful_to_adls import mindful_to_adls from .outlook_to_adls import outlook_to_adls +from .salesforce_to_adls import salesforce_to_adls from .sap_to_parquet import sap_to_parquet from .sap_to_redshift_spectrum import sap_to_redshift_spectrum from 
.sftp_to_adls import sftp_to_adls
@@ -45,6 +46,7 @@
     "mediatool_to_adls",
     "mindful_to_adls",
     "outlook_to_adls",
+    "salesforce_to_adls",
     "sap_to_parquet",
     "sap_to_redshift_spectrum",
     "sftp_to_adls",
diff --git a/src/viadot/orchestration/prefect/flows/salesforce_to_adls.py b/src/viadot/orchestration/prefect/flows/salesforce_to_adls.py
new file mode 100644
index 000000000..cf45505a4
--- /dev/null
+++ b/src/viadot/orchestration/prefect/flows/salesforce_to_adls.py
@@ -0,0 +1,79 @@
+"""Download data from the Salesforce API to Azure Data Lake Storage."""
+
+from prefect import flow
+from prefect.task_runners import ConcurrentTaskRunner
+
+from viadot.orchestration.prefect.tasks import df_to_adls, salesforce_to_df
+
+
+@flow(
+    name="Salesforce extraction to ADLS",
+    description="Extract data from Salesforce and load "
+    + "it into Azure Data Lake Storage.",
+    retries=1,
+    retry_delay_seconds=60,
+    task_runner=ConcurrentTaskRunner,
+)
+def salesforce_to_adls(  # noqa: PLR0913
+    config_key: str | None = None,
+    azure_key_vault_secret: str | None = None,
+    env: str | None = None,
+    domain: str | None = None,
+    client_id: str | None = None,
+    query: str | None = None,
+    table: str | None = None,
+    columns: list[str] | None = None,
+    adls_config_key: str | None = None,
+    adls_azure_key_vault_secret: str | None = None,
+    adls_path: str | None = None,
+    adls_path_overwrite: bool = False,
+) -> None:
+    """Flow to download data from the Salesforce API into Azure Data Lake.
+
+    Args:
+        config_key (str, optional): The key in the viadot config holding relevant
+            credentials. Defaults to None.
+        azure_key_vault_secret (str, optional): The name of the Azure Key Vault secret
+            where credentials are stored. Defaults to None.
+        env (str, optional): Environment information, provides information about
+            credential and connection configuration. Defaults to None, in which
+            case the source falls back to 'DEV'.
+        domain (str, optional): Connection domain; the source defaults to 'test'
+            (sandbox). Applies only when username/password/security token
+            credentials are provided. Defaults to None.
+        client_id (str, optional): Client ID used to keep track of API calls.
+            Defaults to None.
+        query (str, optional): SOQL query used to download specific data.
+            Defaults to None.
+        table (str, optional): Table name. Can be used instead of a query.
+            Defaults to None.
+        columns (list[str], optional): List of columns to download; requires the
+            `table` argument. Defaults to None.
+        adls_config_key (str, optional): The key in the viadot config holding
+            relevant credentials. Defaults to None.
+        adls_azure_key_vault_secret (str, optional): The name of the Azure Key
+            Vault secret containing a dictionary with ACCOUNT_NAME and Service Principal
+            credentials (TENANT_ID, CLIENT_ID, CLIENT_SECRET) for the Azure Data Lake.
+            Defaults to None.
+        adls_path (str, optional): Azure Data Lake destination file path
+            (with file name). Defaults to None.
+        adls_path_overwrite (bool, optional): Whether to overwrite the file in ADLS.
+            Defaults to False.
+    """
+    data_frame = salesforce_to_df(
+        config_key=config_key,
+        azure_key_vault_secret=azure_key_vault_secret,
+        env=env,
+        domain=domain,
+        client_id=client_id,
+        query=query,
+        table=table,
+        columns=columns,
+    )
+
+    return df_to_adls(
+        df=data_frame,
+        path=adls_path,
+        credentials_secret=adls_azure_key_vault_secret,
+        config_key=adls_config_key,
+        overwrite=adls_path_overwrite,
+    )
diff --git a/src/viadot/orchestration/prefect/tasks/__init__.py b/src/viadot/orchestration/prefect/tasks/__init__.py
index 9382594e8..02b8669c3 100644
--- a/src/viadot/orchestration/prefect/tasks/__init__.py
+++ b/src/viadot/orchestration/prefect/tasks/__init__.py
@@ -19,6 +19,7 @@
 from .outlook import outlook_to_df
 from .redshift_spectrum import df_to_redshift_spectrum
 from .s3 import s3_upload_file
+from .salesforce import salesforce_to_df
 from .sap_rfc import sap_rfc_to_df
 from .sftp import sftp_list, sftp_to_df
 from .sharepoint import sharepoint_download_file, sharepoint_to_df
@@ -49,6 +50,7 @@
     "mindful_to_df",
     "outlook_to_df",
     "s3_upload_file",
+    "salesforce_to_df",
     "sap_rfc_to_df",
     "sftp_list",
     "sftp_to_df",
diff --git a/src/viadot/orchestration/prefect/tasks/salesforce.py b/src/viadot/orchestration/prefect/tasks/salesforce.py
new file mode 100644
index 000000000..b2c397f27
--- /dev/null
+++ b/src/viadot/orchestration/prefect/tasks/salesforce.py
@@ -0,0 +1,65 @@
+"""Task to download data from the Salesforce API into a pandas DataFrame."""
+
+import pandas as pd
+from prefect import task
+
+from viadot.orchestration.prefect.exceptions import MissingSourceCredentialsError
+from viadot.orchestration.prefect.utils import get_credentials
+from viadot.sources import Salesforce
+
+
+@task(retries=3, log_prints=True, retry_delay_seconds=10, timeout_seconds=60 * 60)
+def salesforce_to_df(
+    config_key: str | None = None,
+    azure_key_vault_secret: str | None = None,
+    env: str | None = None,
+    domain: str | None = None,
+    client_id: str | None = None,
+    query: str | None = None,
+    table: str | None = None,
+    columns: list[str] | None = None,
+) -> pd.DataFrame:
+    """Query Salesforce and return the data as a pandas DataFrame.
+
+    Args:
+        config_key (str, optional): The key in the viadot config holding relevant
+            credentials. Defaults to None.
+        azure_key_vault_secret (str, optional): The name of the Azure Key Vault secret
+            where credentials are stored. Defaults to None.
+        env (str, optional): Environment information, provides information about
+            credential and connection configuration. Defaults to None, in which
+            case the source falls back to 'DEV'.
+        domain (str, optional): Connection domain; the source defaults to 'test'
+            (sandbox). Applies only when username/password/security token
+            credentials are provided. Defaults to None.
+        client_id (str, optional): Client ID used to keep track of API calls.
+            Defaults to None.
+        query (str, optional): SOQL query used to download specific data.
+            Defaults to None.
+        table (str, optional): Table name. Can be used instead of a query.
+            Defaults to None.
+        columns (list[str], optional): List of columns to download; requires the
+            `table` argument. Defaults to None.
+
+    Returns:
+        pd.DataFrame: The response data as a pandas DataFrame.
+    """
+    if not (azure_key_vault_secret or config_key):
+        raise MissingSourceCredentialsError
+
+    credentials = get_credentials(azure_key_vault_secret) if not config_key else None
+
+    salesforce = Salesforce(
+        credentials=credentials,
+        config_key=config_key,
+        env=env,
+        domain=domain,
+        client_id=client_id,
+    )
+    salesforce.api_connection(
+        query=query,
+        table=table,
+        columns=columns,
+    )
+
+    return salesforce.to_df()
diff --git a/src/viadot/sources/__init__.py b/src/viadot/sources/__init__.py
index 70a374ae7..48e3184f8 100644
--- a/src/viadot/sources/__init__.py
+++ b/src/viadot/sources/__init__.py
@@ -13,6 +13,7 @@
 from .mediatool import Mediatool
 from .mindful import Mindful
 from .outlook import Outlook
+from .salesforce import Salesforce
 from .sftp import Sftp
 from .sharepoint import Sharepoint
 from .sql_server import SQLServer
@@ -32,6 +33,7 @@
     "Sftp",
     "Outlook",
     "SQLServer",
+    "Salesforce",
     "Sharepoint",
     "Supermetrics",
     "SupermetricsCredentials",  # pragma: allowlist-secret
diff --git a/src/viadot/sources/salesforce.py b/src/viadot/sources/salesforce.py
new file mode 100644
index 000000000..c5c0034c1
--- /dev/null
+++ b/src/viadot/sources/salesforce.py
@@ -0,0 +1,150 @@
+"""Salesforce API connector."""
+
+from typing import Literal
+
+import pandas as pd
+from pydantic import BaseModel
+from simple_salesforce import Salesforce as SimpleSalesforce
+
+from viadot.config import get_source_credentials
+from viadot.exceptions import CredentialError
+from viadot.sources.base import Source
+from viadot.utils import add_viadot_metadata_columns
+
+
+class SalesforceCredentials(BaseModel):
+    """Checking for values in the Salesforce credentials dictionary.
+
+    Three key values are held in the Salesforce connector:
+    - username: The unique name for the organization.
+    - password: The unique password for the organization.
+    - token: A unique token to be used as the password for API requests.
+
+    Args:
+        BaseModel (pydantic.main.ModelMetaclass): A base class for creating
+            Pydantic models.
+    """
+
+    username: str
+    password: str
+    token: str
+
+
+class Salesforce(Source):
+    """Class implementing the Salesforce API.
+
+    Documentation for this API is available at:
+    https://developer.salesforce.com/docs/apis.
+    """
+
+    def __init__(
+        self,
+        *args,
+        credentials: SalesforceCredentials | None = None,
+        config_key: str = "salesforce",
+        env: Literal["DEV", "QA", "PROD"] = "DEV",
+        domain: str = "test",
+        client_id: str = "viadot",
+        **kwargs,
+    ):
+        """A class for downloading data from Salesforce.
+
+        Args:
+            credentials (dict(str, any), optional): Salesforce credentials as a
+                dictionary. Defaults to None.
+            config_key (str, optional): The key in the viadot config holding relevant
+                credentials. Defaults to "salesforce".
+            env (Literal["DEV", "QA", "PROD"], optional): Environment information,
+                provides information about credential and connection configuration.
+                Defaults to 'DEV'.
+            domain (str, optional): Domain of a connection. Defaults to 'test'
+                (sandbox). Can only be set if a username/password/security token
+                is provided.
+            client_id (str, optional): Client ID used to keep track of API calls.
+                Defaults to 'viadot'.
+        """
+        credentials = credentials or get_source_credentials(config_key)
+
+        if not (
+            credentials
+            and credentials.get("username")
+            and credentials.get("password")
+            and credentials.get("token")
+        ):
+            message = "'username', 'password' and 'token' credentials are required."
+            raise CredentialError(message)
+
+        validated_creds = dict(SalesforceCredentials(**credentials))
+        super().__init__(*args, credentials=validated_creds, **kwargs)
+
+        if env.upper() == "DEV" or env.upper() == "QA":
+            self.salesforce = SimpleSalesforce(
+                username=self.credentials["username"],
+                password=self.credentials["password"],
+                security_token=self.credentials["token"],
+                domain=domain,
+                client_id=client_id,
+            )
+
+        elif env.upper() == "PROD":
+            self.salesforce = SimpleSalesforce(
+                username=self.credentials["username"],
+                password=self.credentials["password"],
+                security_token=self.credentials["token"],
+            )
+
+        else:
+            message = "The only available environments are DEV, QA, and PROD."
+            raise ValueError(message)
+
+        self.data = None
+
+    def api_connection(
+        self,
+        query: str | None = None,
+        table: str | None = None,
+        columns: list[str] | None = None,
+    ) -> None:
+        """General method to connect to the Salesforce API and generate the response.
+
+        Args:
+            query (str, optional): The query to be used to download the data.
+                Defaults to None.
+            table (str, optional): Table name. Defaults to None.
+            columns (list[str], optional): List of required columns. Requires `table`
+                to be specified. Defaults to None.
+        """
+        if not query:
+            columns_str = ", ".join(columns) if columns else "FIELDS(STANDARD)"
+            query = f"SELECT {columns_str} FROM {table}"  # noqa: S608
+
+        self.data = self.salesforce.query(query).get("records")
+
+        # Remove metadata from the data
+        for record in self.data:
+            record.pop("attributes")
+
+    @add_viadot_metadata_columns
+    def to_df(
+        self,
+        if_empty: str = "fail",
+    ) -> pd.DataFrame:
+        """Download the queried data and return it as a pandas DataFrame.
+
+        Args:
+            if_empty (str, optional): What to do if the fetch produces no data.
+                Defaults to "fail".
+
+        Returns:
+            pd.DataFrame: Selected rows from Salesforce.
+        """
+        df = pd.DataFrame(self.data)
+
+        if df.empty:
+            self._handle_if_empty(
+                if_empty=if_empty,
+                message="The response does not contain any data.",
+            )
+        else:
+            self.logger.info("Successfully downloaded data from the Salesforce API.")
+
+        return df
diff --git a/tests/integration/orchestration/prefect/flows/test_salesforce.py b/tests/integration/orchestration/prefect/flows/test_salesforce.py
new file mode 100644
index 000000000..2eb44a822
--- /dev/null
+++ b/tests/integration/orchestration/prefect/flows/test_salesforce.py
@@ -0,0 +1,19 @@
+"""'test_salesforce.py'."""
+
+from viadot.orchestration.prefect.flows import salesforce_to_adls
+
+
+def test_salesforce_to_adls(
+    azure_key_vault_secret, adls_path, adls_azure_key_vault_secret
+):
+    """Test Salesforce prefect flow."""
+    state = salesforce_to_adls(
+        azure_key_vault_secret=azure_key_vault_secret,
+        env="dev",
+        table="Contact",
+        adls_path=adls_path,
+        adls_azure_key_vault_secret=adls_azure_key_vault_secret,
+        adls_path_overwrite=True,
+    )
+    all_successful = all(s.type == "COMPLETED" for s in state)
+    assert all_successful, "Not all tasks in the flow completed successfully."
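A minimal usage sketch of the `Salesforce` source, assuming credentials are passed inline rather than resolved from the viadot config (default key "salesforce") or Azure Key Vault; the credential values and the `Account` table are placeholders.

    from viadot.sources import Salesforce

    # Placeholder credentials; in practice these come from the viadot config
    # entry (config_key="salesforce") or an Azure Key Vault secret.
    salesforce = Salesforce(
        credentials={
            "username": "<user>",
            "password": "<password>",  # pragma: allowlist secret
            "token": "<security-token>",
        },
        env="DEV",  # DEV/QA log in via the sandbox domain ("test")
    )

    # Either pass a full SOQL query, or a table name (optionally with columns);
    # without columns the query falls back to FIELDS(STANDARD).
    salesforce.api_connection(table="Account", columns=["Id", "Name"])
    df = salesforce.to_df(if_empty="warn")  # appends _viadot_* metadata columns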
diff --git a/tests/unit/test_salesforce.py b/tests/unit/test_salesforce.py new file mode 100644 index 000000000..0a70145b2 --- /dev/null +++ b/tests/unit/test_salesforce.py @@ -0,0 +1,184 @@ +"""'test_salesforce.py'.""" + +import pytest +from simple_salesforce import Salesforce as SimpleSalesforce + +from viadot.exceptions import CredentialError +from viadot.sources import Salesforce +from viadot.sources.salesforce import SalesforceCredentials + + +variables = { + "credentials": { + "username": "test_user", + "password": "test_password", # pragma: allowlist secret + "token": "test_token", + }, + "records_1": [ + { + "Id": "001", + "Name": "Test Record", + "attributes": { + "type": "Account", + "url": "/services/data/v50.0/sobjects/Account/001", + }, + }, + ], + "records_2": [ + { + "Id": "001", + "Name": "Test Record", + "attributes": { + "type": "Account", + "url": "/services/data/v50.0/sobjects/Account/001", + }, + }, + ], + "data": [ + {"Id": "001", "Name": "Test Record"}, + {"Id": "002", "Name": "Another Record"}, + ], +} + + +@pytest.mark.basic +def test_salesforce_init_dev_env(mocker): + """Test Salesforce, starting in dev mode.""" + mock_sf_instance = mocker.MagicMock(spec=SimpleSalesforce) + mocker.patch( + "viadot.sources.salesforce.SimpleSalesforce", return_value=mock_sf_instance + ) + sf_instance = Salesforce(credentials=variables["credentials"], env="DEV") + + assert sf_instance.salesforce == mock_sf_instance + + +class TestSalesforceCredentials: + """Test Salesforce Credentials Class.""" + + @pytest.mark.basic + def test_salesforce_credentials(self): + """Test Salesforce credentials.""" + SalesforceCredentials( + username="test_user", + password="test_password", # noqa: S106 # pragma: allowlist secret + token="test_token", # noqa: S106 + ) + + +@pytest.mark.basic +def test_salesforce_init_prod_env(mocker): + """Test Salesforce, starting in prod mode.""" + mock_sf_instance = mocker.MagicMock(spec=SimpleSalesforce) + mocker.patch( + "viadot.sources.salesforce.SimpleSalesforce", return_value=mock_sf_instance + ) + sf_instance = Salesforce(credentials=variables["credentials"], env="PROD") + + assert sf_instance.salesforce == mock_sf_instance + + +@pytest.mark.basic +def test_salesforce_invalid_env(): + """Test Salesforce, invalid `env` parameter.""" + with pytest.raises( + ValueError, match="The only available environments are DEV, QA, and PROD." 
+ ): + Salesforce(credentials=variables["credentials"], env="INVALID") + + +@pytest.mark.basic +def test_salesforce_missing_credentials(): + """Test Salesforce missing credentials.""" + incomplete_creds = { + "username": "user", # pragma: allowlist secret + "password": "pass", # pragma: allowlist secret + } + with pytest.raises(CredentialError): + Salesforce(credentials=incomplete_creds) + + +@pytest.mark.connect +def test_salesforce_api_connection(mocker): + """Test Salesforce `api_connection` method with a query.""" + mock_sf_instance = mocker.MagicMock(spec=SimpleSalesforce) + mocker.patch( + "viadot.sources.salesforce.SimpleSalesforce", return_value=mock_sf_instance + ) + salesforce_instance = Salesforce(credentials=variables["credentials"]) + + mock_sf_instance.query.return_value = {"records": variables["records_1"]} + + salesforce_instance.api_connection(query="SELECT Id, Name FROM Account") + + assert salesforce_instance.data == [{"Id": "001", "Name": "Test Record"}] + mock_sf_instance.query.assert_called_once_with("SELECT Id, Name FROM Account") + + +@pytest.mark.connect +def test_salesforce_api_connection_with_columns(mocker): + """Test Salesforce `api_connection` method with columns.""" + mock_sf_instance = mocker.MagicMock(spec=SimpleSalesforce) + mocker.patch( + "viadot.sources.salesforce.SimpleSalesforce", return_value=mock_sf_instance + ) + salesforce_instance = Salesforce(credentials=variables["credentials"]) + + mock_sf_instance.query.return_value = {"records": variables["records_2"]} + + salesforce_instance.api_connection(table="Account", columns=["Id", "Name"]) + + assert salesforce_instance.data == [{"Id": "001", "Name": "Test Record"}] + mock_sf_instance.query.assert_called_once_with("SELECT Id, Name FROM Account") + + +@pytest.mark.functions +def test_salesforce_to_df(mocker): + """Test Salesforce `to_df` method.""" + mock_sf_instance = mocker.MagicMock(spec=SimpleSalesforce) + mocker.patch( + "viadot.sources.salesforce.SimpleSalesforce", return_value=mock_sf_instance + ) + salesforce_instance = Salesforce(credentials=variables["credentials"]) + salesforce_instance.data = variables["data"] + + df = salesforce_instance.to_df() + + assert not df.empty + assert df.shape == (2, 4) + assert list(df.columns) == [ + "Id", + "Name", + "_viadot_source", + "_viadot_downloaded_at_utc", + ] + assert df.iloc[0]["Id"] == "001" + + +@pytest.mark.functions +def test_salesforce_to_df_empty_data(mocker): + """Test Salesforce `to_df` method with empty df.""" + mock_sf_instance = mocker.MagicMock(spec=SimpleSalesforce) + mocker.patch( + "viadot.sources.salesforce.SimpleSalesforce", return_value=mock_sf_instance + ) + salesforce_instance = Salesforce(credentials=variables["credentials"]) + salesforce_instance.data = [] + + with pytest.raises(ValueError, match="The response does not contain any data."): + salesforce_instance.to_df(if_empty="fail") + + +@pytest.mark.functions +def test_salesforce_to_df_warn_empty_data(mocker): + """Test Salesforce `to_df` method with empty df, warn.""" + mock_sf_instance = mocker.MagicMock(spec=SimpleSalesforce) + mocker.patch( + "viadot.sources.salesforce.SimpleSalesforce", return_value=mock_sf_instance + ) + salesforce_instance = Salesforce(credentials=variables["credentials"]) + salesforce_instance.data = [] + + df = salesforce_instance.to_df(if_empty="warn") + + assert df.empty
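A minimal sketch of running the new `salesforce_to_adls` flow locally, mirroring the integration test above; the Key Vault secret names and the ADLS path are placeholders, and the secrets are assumed to hold the Salesforce and ADLS credentials described in the flow docstring.

    from viadot.orchestration.prefect.flows import salesforce_to_adls

    # Executes the flow in-process: salesforce_to_df -> df_to_adls.
    salesforce_to_adls(
        azure_key_vault_secret="salesforce-dev",  # placeholder secret name
        env="DEV",
        table="Contact",
        adls_path="raw/salesforce/contact.parquet",  # placeholder destination path
        adls_azure_key_vault_secret="adls-dev",  # placeholder secret name
        adls_path_overwrite=True,
    )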