From 009de792cf88be829486b93dc5cb41aea24d0580 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Mon, 23 Oct 2023 15:28:56 +0800
Subject: [PATCH] Fix azure provider breaking change (#1341)

* fix(microsoft): make resource_group_name and factory_name required arguments

* fix(microsoft): add missing types to fix example_adf_run_pipeline.py

  * these are added in a newer version of the datafactory lib

* fix(azure): remove the deprecated PipelineRunInfo
---
 .circleci/config.yml                          |  2 +-
 .../example_dags/example_adf_run_pipeline.py  | 22 +++++++++++--------
 .../microsoft/azure/hooks/data_factory.py     | 12 +++++-----
 .../microsoft/azure/operators/data_factory.py |  6 ++---
 .../microsoft/azure/triggers/data_factory.py  | 10 ++++-----
 setup.cfg                                     |  1 +
 .../azure/operators/test_data_factory.py      |  2 ++
 .../azure/sensors/test_data_factory.py        |  8 ++++++-
 8 files changed, 37 insertions(+), 26 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 753da3469..12d991eb4 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -160,7 +160,7 @@ jobs:
             sudo apt-get install libsasl2-dev
       - run:
           name: Install Dependencies
-          command: pip install -U -e .[all,docs,mypy]
+          command: pip install -U -e .[docs]
       - run:
           name: Run Sphinx
           command: |
diff --git a/astronomer/providers/microsoft/azure/example_dags/example_adf_run_pipeline.py b/astronomer/providers/microsoft/azure/example_dags/example_adf_run_pipeline.py
index 189fa736c..47f9fc68b 100644
--- a/astronomer/providers/microsoft/azure/example_dags/example_adf_run_pipeline.py
+++ b/astronomer/providers/microsoft/azure/example_dags/example_adf_run_pipeline.py
@@ -40,8 +40,6 @@
 default_args = {
     "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
     "azure_data_factory_conn_id": "azure_data_factory_default",
-    "factory_name": DATAFACTORY_NAME,  # This can also be specified in the ADF connection.
-    "resource_group_name": RESOURCE_GROUP_NAME,  # This can also be specified in the ADF connection.
"retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)), "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))), } @@ -88,7 +86,7 @@ def create_adf_storage_pipeline() -> None: df_resource = Factory(location=LOCATION) df = adf_client.factories.create_or_update(RESOURCE_GROUP_NAME, DATAFACTORY_NAME, df_resource) while df.provisioning_state != "Succeeded": - df = adf_client.factories.get(RESOURCE_GROUP_NAME, DATAFACTORY_NAME) + df = adf_client.factories.get(RESOURCE_GROUP_NAME, DATAFACTORY_NAME) # type: ignore[assignment] time.sleep(1) # Create an Azure Storage linked service @@ -97,17 +95,17 @@ def create_adf_storage_pipeline() -> None: storage_string = SecureString(value=CONNECTION_STRING) ls_azure_storage = LinkedServiceResource( - properties=AzureStorageLinkedService(connection_string=storage_string) + properties=AzureStorageLinkedService(connection_string=storage_string) # type: ignore[arg-type] ) adf_client.linked_services.create_or_update( RESOURCE_GROUP_NAME, DATAFACTORY_NAME, STORAGE_LINKED_SERVICE_NAME, ls_azure_storage ) # Create an Azure blob dataset (input) - ds_ls = LinkedServiceReference(reference_name=STORAGE_LINKED_SERVICE_NAME) + ds_ls = LinkedServiceReference(type="LinkedServiceReference", reference_name=STORAGE_LINKED_SERVICE_NAME) ds_azure_blob = DatasetResource( properties=AzureBlobDataset( - linked_service_name=ds_ls, folder_path=BLOB_PATH, file_name=BLOB_FILE_NAME + linked_service_name=ds_ls, folder_path=BLOB_PATH, file_name=BLOB_FILE_NAME # type: ignore[arg-type] ) ) adf_client.datasets.create_or_update( @@ -116,7 +114,7 @@ def create_adf_storage_pipeline() -> None: # Create an Azure blob dataset (output) ds_out_azure_blob = DatasetResource( - properties=AzureBlobDataset(linked_service_name=ds_ls, folder_path=OUTPUT_BLOB_PATH) + properties=AzureBlobDataset(linked_service_name=ds_ls, folder_path=OUTPUT_BLOB_PATH) # type: ignore[arg-type] ) adf_client.datasets.create_or_update( RESOURCE_GROUP_NAME, DATAFACTORY_NAME, DATASET_OUTPUT_NAME, ds_out_azure_blob @@ -125,8 +123,8 @@ def create_adf_storage_pipeline() -> None: # Create a copy activity blob_source = BlobSource() blob_sink = BlobSink() - ds_in_ref = DatasetReference(reference_name=DATASET_INPUT_NAME) - ds_out_ref = DatasetReference(reference_name=DATASET_OUTPUT_NAME) + ds_in_ref = DatasetReference(type="DatasetReference", reference_name=DATASET_INPUT_NAME) + ds_out_ref = DatasetReference(type="DatasetReference", reference_name=DATASET_OUTPUT_NAME) copy_activity = CopyActivity( name=ACTIVITY_NAME, inputs=[ds_in_ref], outputs=[ds_out_ref], source=blob_source, sink=blob_sink ) @@ -194,6 +192,8 @@ def delete_azure_data_factory_storage_pipeline() -> None: run_pipeline_wait = AzureDataFactoryRunPipelineOperatorAsync( task_id="run_pipeline_wait", pipeline_name=PIPELINE_NAME, + factory_name=DATAFACTORY_NAME, + resource_group_name=RESOURCE_GROUP_NAME, ) # [END howto_operator_adf_run_pipeline_async] @@ -201,6 +201,8 @@ def delete_azure_data_factory_storage_pipeline() -> None: run_pipeline_no_wait = AzureDataFactoryRunPipelineOperatorAsync( task_id="run_pipeline_no_wait", pipeline_name=PIPELINE_NAME, + factory_name=DATAFACTORY_NAME, + resource_group_name=RESOURCE_GROUP_NAME, wait_for_termination=False, ) # [END howto_operator_adf_run_pipeline] @@ -209,6 +211,8 @@ def delete_azure_data_factory_storage_pipeline() -> None: pipeline_run_sensor_async = AzureDataFactoryPipelineRunStatusSensorAsync( task_id="pipeline_run_sensor_async", run_id=cast(str, XComArg(run_pipeline_wait, key="run_id")), + 
+        factory_name=DATAFACTORY_NAME,
+        resource_group_name=RESOURCE_GROUP_NAME,
     )
     # [END howto_sensor_pipeline_run_sensor_async]
diff --git a/astronomer/providers/microsoft/azure/hooks/data_factory.py b/astronomer/providers/microsoft/azure/hooks/data_factory.py
index 68151c385..268be4dea 100644
--- a/astronomer/providers/microsoft/azure/hooks/data_factory.py
+++ b/astronomer/providers/microsoft/azure/hooks/data_factory.py
@@ -75,14 +75,14 @@ class AzureDataFactoryHookAsync(AzureDataFactoryHook):

     def __init__(self, azure_data_factory_conn_id: str):
         """Initialize the hook instance."""
-        self._async_conn: DataFactoryManagementClient = None
+        self._async_conn: DataFactoryManagementClient | None = None
         self.conn_id = azure_data_factory_conn_id
         super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)

     async def get_async_conn(self) -> DataFactoryManagementClient:
         """Get async connection and connect to azure data factory."""
         if self._conn is not None:
-            return self._conn
+            return cast(DataFactoryManagementClient, self._conn)  # pragma: no cover

         conn = await sync_to_async(self.get_connection)(self.conn_id)
         extras = conn.extra_dejson
@@ -113,8 +113,8 @@ async def get_async_conn(self) -> DataFactoryManagementClient:
     async def get_pipeline_run(
         self,
         run_id: str,
-        resource_group_name: str | None = None,
-        factory_name: str | None = None,
+        resource_group_name: str,
+        factory_name: str,
         **config: Any,
     ) -> PipelineRun:
         """
@@ -132,7 +132,7 @@ async def get_pipeline_run(
             raise AirflowException(e)

     async def get_adf_pipeline_run_status(
-        self, run_id: str, resource_group_name: str | None = None, factory_name: str | None = None
+        self, run_id: str, resource_group_name: str, factory_name: str
     ) -> str:
         """
         Connect to Azure Data Factory asynchronously and get the pipeline status by run_id.
@@ -147,7 +147,7 @@ async def get_adf_pipeline_run_status(
                 factory_name=factory_name,
                 resource_group_name=resource_group_name,
             )
-            status: str = pipeline_run.status
+            status: str = cast(str, pipeline_run.status)
             return status
         except Exception as e:
             raise AirflowException(e)
diff --git a/astronomer/providers/microsoft/azure/operators/data_factory.py b/astronomer/providers/microsoft/azure/operators/data_factory.py
index b9a2b6eb8..85bcdf2e1 100644
--- a/astronomer/providers/microsoft/azure/operators/data_factory.py
+++ b/astronomer/providers/microsoft/azure/operators/data_factory.py
@@ -6,7 +6,6 @@
     AzureDataFactoryHook,
     AzureDataFactoryPipelineRunException,
     AzureDataFactoryPipelineRunStatus,
-    PipelineRunInfo,
 )
 from airflow.providers.microsoft.azure.operators.data_factory import (
     AzureDataFactoryRunPipelineOperator,
@@ -67,12 +66,11 @@ def execute(self, context: Context) -> None:
         context["ti"].xcom_push(key="run_id", value=run_id)
         end_time = time.time() + self.timeout

-        pipeline_run_info = PipelineRunInfo(
+        pipeline_run_status = hook.get_pipeline_run_status(
             run_id=run_id,
-            factory_name=self.factory_name,
             resource_group_name=self.resource_group_name,
+            factory_name=self.factory_name,
         )
-        pipeline_run_status = hook.get_pipeline_run_status(**pipeline_run_info)
         if pipeline_run_status not in AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES:
             self.defer(
                 timeout=self.execution_timeout,
diff --git a/astronomer/providers/microsoft/azure/triggers/data_factory.py b/astronomer/providers/microsoft/azure/triggers/data_factory.py
index 509751438..1628ddd6b 100644
--- a/astronomer/providers/microsoft/azure/triggers/data_factory.py
+++ b/astronomer/providers/microsoft/azure/triggers/data_factory.py
@@ -1,6 +1,6 @@
 import asyncio
 import time
-from typing import Any, AsyncIterator, Dict, List, Optional, Tuple
+from typing import Any, AsyncIterator, Dict, List, Tuple

 from airflow.providers.microsoft.azure.hooks.data_factory import (
     AzureDataFactoryPipelineRunStatus,
@@ -29,8 +29,8 @@ def __init__(
         run_id: str,
         azure_data_factory_conn_id: str,
         poke_interval: float,
-        resource_group_name: Optional[str] = None,
-        factory_name: Optional[str] = None,
+        resource_group_name: str,
+        factory_name: str,
     ):
         super().__init__()
         self.run_id = run_id
@@ -108,8 +108,8 @@ def __init__(
         run_id: str,
         azure_data_factory_conn_id: str,
         end_time: float,
-        resource_group_name: Optional[str] = None,
-        factory_name: Optional[str] = None,
+        resource_group_name: str,
+        factory_name: str,
         wait_for_termination: bool = True,
         check_interval: int = 60,
     ):
diff --git a/setup.cfg b/setup.cfg
index a6a74edf8..4f5ffc68e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -114,6 +114,7 @@ mypy =
     types-Markdown
     types-PyMySQL
    types-PyYAML
+    snowflake-connector-python>=3.3.0  # Temporary workaround so that pip can find a proper connector version

# All extras from above except 'mypy', 'docs' and 'tests'
all =
diff --git a/tests/microsoft/azure/operators/test_data_factory.py b/tests/microsoft/azure/operators/test_data_factory.py
index 765f9a8d3..992ff2d62 100644
--- a/tests/microsoft/azure/operators/test_data_factory.py
+++ b/tests/microsoft/azure/operators/test_data_factory.py
@@ -25,6 +25,8 @@ class TestAzureDataFactoryRunPipelineOperatorAsync:
         task_id="run_pipeline",
         pipeline_name="pipeline",
         parameters={"myParam": "value"},
+        factory_name="factory_name",
+        resource_group_name="resource_group",
     )

     @mock.patch(f"{MODULE}.AzureDataFactoryRunPipelineOperatorAsync.defer")
diff --git a/tests/microsoft/azure/sensors/test_data_factory.py b/tests/microsoft/azure/sensors/test_data_factory.py
index e961854e3..da1810614 100644
--- a/tests/microsoft/azure/sensors/test_data_factory.py
+++ b/tests/microsoft/azure/sensors/test_data_factory.py
@@ -19,6 +19,8 @@ class TestAzureDataFactoryPipelineRunStatusSensorAsync:
     SENSOR = AzureDataFactoryPipelineRunStatusSensorAsync(
         task_id="pipeline_run_sensor_async",
         run_id=RUN_ID,
+        factory_name="factory_name",
+        resource_group_name="resource_group_name",
     )

     @mock.patch(f"{MODULE}.AzureDataFactoryPipelineRunStatusSensorAsync.defer")
@@ -61,5 +63,9 @@ def test_poll_interval_deprecation_warning(self):
         # TODO: Remove once deprecated
         with pytest.warns(expected_warning=DeprecationWarning):
             AzureDataFactoryPipelineRunStatusSensorAsync(
-                task_id="pipeline_run_sensor_async", run_id=self.RUN_ID, poll_interval=5.0
+                task_id="pipeline_run_sensor_async",
+                run_id=self.RUN_ID,
+                poll_interval=5.0,
+                factory_name="factory_name",
+                resource_group_name="resource_group_name",
             )
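
Editor's note (not part of the patch): since resource_group_name and factory_name are now required, downstream DAGs that relied on the old optional defaults must pass both explicitly. Below is a minimal sketch of an upgraded DAG. It assumes the operator/sensor import paths implied by this diff, an Airflow version new enough for the `schedule` kwarg (2.4+), and uses placeholder pipeline/factory/resource-group names; the XComArg wiring mirrors the example DAG in this patch:

    from datetime import datetime
    from typing import cast

    from airflow import DAG
    from airflow.models import XComArg
    from astronomer.providers.microsoft.azure.operators.data_factory import (
        AzureDataFactoryRunPipelineOperatorAsync,
    )
    from astronomer.providers.microsoft.azure.sensors.data_factory import (
        AzureDataFactoryPipelineRunStatusSensorAsync,
    )

    with DAG("example_adf_upgrade", start_date=datetime(2023, 10, 1), schedule=None):
        run_pipeline = AzureDataFactoryRunPipelineOperatorAsync(
            task_id="run_pipeline",
            pipeline_name="my-pipeline",   # placeholder
            factory_name="my-factory",     # now required (was optional)
            resource_group_name="my-rg",   # now required (was optional)
        )
        # The sensor takes the same two now-required arguments; run_id is
        # pulled from the operator's XCom, as in the example DAG above.
        wait_for_run = AzureDataFactoryPipelineRunStatusSensorAsync(
            task_id="wait_for_run",
            run_id=cast(str, XComArg(run_pipeline, key="run_id")),
            factory_name="my-factory",
            resource_group_name="my-rg",
        )
        run_pipeline >> wait_for_run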