Fix azure provider breaking change (#1341)
* fix(microsoft): make resource_group_name and factory_name required arguments
* fix(microsoft): add the missing type arguments needed to fix example_adf_run_pipeline.py
    * these were added in a newer version of the datafactory lib
* fix(azure): remove the deprecated PipelineRunInfo
Lee-W authored Oct 23, 2023
1 parent 5c18bc2 commit 009de79
Showing 8 changed files with 37 additions and 26 deletions.
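The headline change for users: resource_group_name and factory_name are no longer optional, so they can no longer be left to resolve from the ADF connection. A minimal sketch of the new call pattern, assuming an existing DAG context (the pipeline, factory, and resource group names below are illustrative placeholders):

    from astronomer.providers.microsoft.azure.operators.data_factory import (
        AzureDataFactoryRunPipelineOperatorAsync,
    )

    run_pipeline = AzureDataFactoryRunPipelineOperatorAsync(
        task_id="run_pipeline",
        pipeline_name="my-pipeline",  # illustrative
        factory_name="my-data-factory",  # now a required argument
        resource_group_name="my-resource-group",  # now a required argument
    )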
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -160,7 +160,7 @@ jobs:
          sudo apt-get install libsasl2-dev
      - run:
          name: Install Dependencies
-         command: pip install -U -e .[all,docs,mypy]
+         command: pip install -U -e .[docs]
      - run:
          name: Run Sphinx
          command: |
astronomer/providers/microsoft/azure/example_dags/example_adf_run_pipeline.py
@@ -40,8 +40,6 @@
default_args = {
    "execution_timeout": timedelta(hours=EXECUTION_TIMEOUT),
    "azure_data_factory_conn_id": "azure_data_factory_default",
-   "factory_name": DATAFACTORY_NAME,  # This can also be specified in the ADF connection.
-   "resource_group_name": RESOURCE_GROUP_NAME,  # This can also be specified in the ADF connection.
    "retries": int(os.getenv("DEFAULT_TASK_RETRIES", 2)),
    "retry_delay": timedelta(seconds=int(os.getenv("DEFAULT_RETRY_DELAY_SECONDS", 60))),
}
@@ -88,7 +86,7 @@ def create_adf_storage_pipeline() -> None:
    df_resource = Factory(location=LOCATION)
    df = adf_client.factories.create_or_update(RESOURCE_GROUP_NAME, DATAFACTORY_NAME, df_resource)
    while df.provisioning_state != "Succeeded":
-       df = adf_client.factories.get(RESOURCE_GROUP_NAME, DATAFACTORY_NAME)
+       df = adf_client.factories.get(RESOURCE_GROUP_NAME, DATAFACTORY_NAME)  # type: ignore[assignment]
        time.sleep(1)

    # Create an Azure Storage linked service
@@ -97,17 +95,17 @@ def create_adf_storage_pipeline() -> None:
    storage_string = SecureString(value=CONNECTION_STRING)

    ls_azure_storage = LinkedServiceResource(
-       properties=AzureStorageLinkedService(connection_string=storage_string)
+       properties=AzureStorageLinkedService(connection_string=storage_string)  # type: ignore[arg-type]
    )
    adf_client.linked_services.create_or_update(
        RESOURCE_GROUP_NAME, DATAFACTORY_NAME, STORAGE_LINKED_SERVICE_NAME, ls_azure_storage
    )

    # Create an Azure blob dataset (input)
-   ds_ls = LinkedServiceReference(reference_name=STORAGE_LINKED_SERVICE_NAME)
+   ds_ls = LinkedServiceReference(type="LinkedServiceReference", reference_name=STORAGE_LINKED_SERVICE_NAME)
    ds_azure_blob = DatasetResource(
        properties=AzureBlobDataset(
-           linked_service_name=ds_ls, folder_path=BLOB_PATH, file_name=BLOB_FILE_NAME
+           linked_service_name=ds_ls, folder_path=BLOB_PATH, file_name=BLOB_FILE_NAME  # type: ignore[arg-type]
        )
    )
    adf_client.datasets.create_or_update(
@@ -116,7 +114,7 @@ def create_adf_storage_pipeline() -> None:

    # Create an Azure blob dataset (output)
    ds_out_azure_blob = DatasetResource(
-       properties=AzureBlobDataset(linked_service_name=ds_ls, folder_path=OUTPUT_BLOB_PATH)
+       properties=AzureBlobDataset(linked_service_name=ds_ls, folder_path=OUTPUT_BLOB_PATH)  # type: ignore[arg-type]
    )
    adf_client.datasets.create_or_update(
        RESOURCE_GROUP_NAME, DATAFACTORY_NAME, DATASET_OUTPUT_NAME, ds_out_azure_blob
@@ -125,8 +123,8 @@ def create_adf_storage_pipeline() -> None:
    # Create a copy activity
    blob_source = BlobSource()
    blob_sink = BlobSink()
-   ds_in_ref = DatasetReference(reference_name=DATASET_INPUT_NAME)
-   ds_out_ref = DatasetReference(reference_name=DATASET_OUTPUT_NAME)
+   ds_in_ref = DatasetReference(type="DatasetReference", reference_name=DATASET_INPUT_NAME)
+   ds_out_ref = DatasetReference(type="DatasetReference", reference_name=DATASET_OUTPUT_NAME)
    copy_activity = CopyActivity(
        name=ACTIVITY_NAME, inputs=[ds_in_ref], outputs=[ds_out_ref], source=blob_source, sink=blob_sink
    )
@@ -194,13 +192,17 @@ def delete_azure_data_factory_storage_pipeline() -> None:
    run_pipeline_wait = AzureDataFactoryRunPipelineOperatorAsync(
        task_id="run_pipeline_wait",
        pipeline_name=PIPELINE_NAME,
+       factory_name=DATAFACTORY_NAME,
+       resource_group_name=RESOURCE_GROUP_NAME,
    )
    # [END howto_operator_adf_run_pipeline_async]

    # [START howto_operator_adf_run_pipeline]
    run_pipeline_no_wait = AzureDataFactoryRunPipelineOperatorAsync(
        task_id="run_pipeline_no_wait",
        pipeline_name=PIPELINE_NAME,
+       factory_name=DATAFACTORY_NAME,
+       resource_group_name=RESOURCE_GROUP_NAME,
        wait_for_termination=False,
    )
    # [END howto_operator_adf_run_pipeline]
@@ -209,6 +211,8 @@ def delete_azure_data_factory_storage_pipeline() -> None:
    pipeline_run_sensor_async = AzureDataFactoryPipelineRunStatusSensorAsync(
        task_id="pipeline_run_sensor_async",
        run_id=cast(str, XComArg(run_pipeline_wait, key="run_id")),
+       factory_name=DATAFACTORY_NAME,
+       resource_group_name=RESOURCE_GROUP_NAME,
    )
    # [END howto_sensor_pipeline_run_sensor_async]

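The example DAG now passes factory_name and resource_group_name explicitly on each Data Factory task rather than through default_args, matching the new required-argument signatures; the old comments noting that the values could instead live in the ADF connection no longer apply.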
12 changes: 6 additions & 6 deletions astronomer/providers/microsoft/azure/hooks/data_factory.py
@@ -75,14 +75,14 @@ class AzureDataFactoryHookAsync(AzureDataFactoryHook):

    def __init__(self, azure_data_factory_conn_id: str):
        """Initialize the hook instance."""
-       self._async_conn: DataFactoryManagementClient = None
+       self._async_conn: DataFactoryManagementClient | None = None
        self.conn_id = azure_data_factory_conn_id
        super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)

    async def get_async_conn(self) -> DataFactoryManagementClient:
        """Get async connection and connect to azure data factory."""
        if self._conn is not None:
-           return self._conn
+           return cast(DataFactoryManagementClient, self._conn)  # pragma: no cover

        conn = await sync_to_async(self.get_connection)(self.conn_id)
        extras = conn.extra_dejson
@@ -113,8 +113,8 @@ async def get_async_conn(self) -> DataFactoryManagementClient:
    async def get_pipeline_run(
        self,
        run_id: str,
-       resource_group_name: str | None = None,
-       factory_name: str | None = None,
+       resource_group_name: str,
+       factory_name: str,
        **config: Any,
    ) -> PipelineRun:
        """
@@ -132,7 +132,7 @@ async def get_pipeline_run(
            raise AirflowException(e)

    async def get_adf_pipeline_run_status(
-       self, run_id: str, resource_group_name: str | None = None, factory_name: str | None = None
+       self, run_id: str, resource_group_name: str, factory_name: str
    ) -> str:
        """
        Connect to Azure Data Factory asynchronously and get the pipeline status by run_id.
@@ -147,7 +147,7 @@ async def get_adf_pipeline_run_status(
                factory_name=factory_name,
                resource_group_name=resource_group_name,
            )
-           status: str = pipeline_run.status
+           status: str = cast(str, pipeline_run.status)
            return status
        except Exception as e:
            raise AirflowException(e)
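With the Optional defaults removed from the hook, every call site must now supply both identifiers. A minimal usage sketch, assuming a matching Airflow connection exists (the connection id, run id, and resource names are illustrative):

    import asyncio

    from astronomer.providers.microsoft.azure.hooks.data_factory import AzureDataFactoryHookAsync

    async def check_status() -> str:
        hook = AzureDataFactoryHookAsync(azure_data_factory_conn_id="azure_data_factory_default")
        # resource_group_name and factory_name are now required keyword arguments.
        return await hook.get_adf_pipeline_run_status(
            run_id="some-run-id",  # illustrative
            resource_group_name="my-resource-group",  # illustrative
            factory_name="my-data-factory",  # illustrative
        )

    print(asyncio.run(check_status()))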
astronomer/providers/microsoft/azure/operators/data_factory.py
@@ -6,7 +6,6 @@
    AzureDataFactoryHook,
    AzureDataFactoryPipelineRunException,
    AzureDataFactoryPipelineRunStatus,
-   PipelineRunInfo,
)
from airflow.providers.microsoft.azure.operators.data_factory import (
    AzureDataFactoryRunPipelineOperator,
@@ -67,12 +66,11 @@ def execute(self, context: Context) -> None:
context["ti"].xcom_push(key="run_id", value=run_id)
end_time = time.time() + self.timeout

pipeline_run_info = PipelineRunInfo(
pipeline_run_status = hook.get_pipeline_run_status(
run_id=run_id,
factory_name=self.factory_name,
resource_group_name=self.resource_group_name,
factory_name=self.factory_name,
)
pipeline_run_status = hook.get_pipeline_run_status(**pipeline_run_info)
if pipeline_run_status not in AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES:
self.defer(
timeout=self.execution_timeout,
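With the deprecated PipelineRunInfo container removed, the operator passes run_id, resource_group_name, and factory_name straight to hook.get_pipeline_run_status and defers only when the run has not yet reached a terminal state.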
10 changes: 5 additions & 5 deletions astronomer/providers/microsoft/azure/triggers/data_factory.py
@@ -1,6 +1,6 @@
 import asyncio
 import time
-from typing import Any, AsyncIterator, Dict, List, Optional, Tuple
+from typing import Any, AsyncIterator, Dict, List, Tuple

 from airflow.providers.microsoft.azure.hooks.data_factory import (
     AzureDataFactoryPipelineRunStatus,
@@ -29,8 +29,8 @@ def __init__(
        run_id: str,
        azure_data_factory_conn_id: str,
        poke_interval: float,
-       resource_group_name: Optional[str] = None,
-       factory_name: Optional[str] = None,
+       resource_group_name: str,
+       factory_name: str,
    ):
        super().__init__()
        self.run_id = run_id
@@ -108,8 +108,8 @@ def __init__(
        run_id: str,
        azure_data_factory_conn_id: str,
        end_time: float,
-       resource_group_name: Optional[str] = None,
-       factory_name: Optional[str] = None,
+       resource_group_name: str,
+       factory_name: str,
        wait_for_termination: bool = True,
        check_interval: int = 60,
1 change: 1 addition & 0 deletions setup.cfg
@@ -114,6 +114,7 @@ mypy =
    types-Markdown
    types-PyMySQL
    types-PyYAML
+   snowflake-connector-python>=3.3.0  # Temporary workaround: pip otherwise cannot find a proper connector version

# All extras from above except 'mypy', 'docs' and 'tests'
all =
2 changes: 2 additions & 0 deletions tests/microsoft/azure/operators/test_data_factory.py
@@ -25,6 +25,8 @@ class TestAzureDataFactoryRunPipelineOperatorAsync:
task_id="run_pipeline",
pipeline_name="pipeline",
parameters={"myParam": "value"},
factory_name="factory_name",
resource_group_name="resource_group",
)

@mock.patch(f"{MODULE}.AzureDataFactoryRunPipelineOperatorAsync.defer")
8 changes: 7 additions & 1 deletion tests/microsoft/azure/sensors/test_data_factory.py
@@ -19,6 +19,8 @@ class TestAzureDataFactoryPipelineRunStatusSensorAsync:
    SENSOR = AzureDataFactoryPipelineRunStatusSensorAsync(
        task_id="pipeline_run_sensor_async",
        run_id=RUN_ID,
+       factory_name="factory_name",
+       resource_group_name="resource_group_name",
    )

@mock.patch(f"{MODULE}.AzureDataFactoryPipelineRunStatusSensorAsync.defer")
@@ -61,5 +63,9 @@ def test_poll_interval_deprecation_warning(self):
        # TODO: Remove once deprecated
        with pytest.warns(expected_warning=DeprecationWarning):
            AzureDataFactoryPipelineRunStatusSensorAsync(
-               task_id="pipeline_run_sensor_async", run_id=self.RUN_ID, poll_interval=5.0
+               task_id="pipeline_run_sensor_async",
+               run_id=self.RUN_ID,
+               poll_interval=5.0,
+               factory_name="factory_name",
+               resource_group_name="resource_group_name",
            )
