Test Airflow DAG locally

Installing the Python libraries

Airflow DAGs can be tested and integrated within your unittest workflow.

For that, apache-airflow and pytest are the only Python packages you need.
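For instance, you can install them with pip (or the package manager of your choice):

pip install apache-airflow pytest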

First, import the libraries and retrieve the directory the test script lives in:

from pathlib import Path
from airflow.models import DagBag
from unittest.mock import patch
import pytest

SCRIPT_DIRECTORY = Path(__file__).parent

Collecting the DAGs in the DagBag

Second, collect all the local DAGs you have under your dags/ folder and want to test, using airflow.models.DagBag. You can create a dedicated dag_bag fixture for that task:

@pytest.fixture()
def dag_bag() -> DagBag:
    dag_folder = SCRIPT_DIRECTORY / ".." / "dags"
    dag_bag = DagBag(
        dag_folder=dag_folder,
        read_dags_from_db=False,
    )
    return dag_bag

This fixture returns a collection of DAGs, parsed from the local dag folder tree you have specified.

Note: the above fixture is tailored for a project with a structure similar to this:

airflow-dag-repo
├── dags # all your dags go there
    └── dag.py
├── airflow_dag_repo
    ├── __init__.py
    └── commons.py
├── tests
    └── test_dag.py
├── poetry.lock
└── pyproject.toml

Optional: I use poetry as Python package manager; you can learn more about it here.

Note: the fixture decorator is used as a setup tool to initialize reusable objects in one place and pass them to your test functions as arguments. Here, the dag_bag object can now be accessed by all the test functions in that module.

Running the test suite on the collected DAGs

Finally, you can implement your tests:

def test_dag_tasks_count(dag_bag):
    dag = dag_bag.get_dag(dag_id="your-dag-id")
    assert dag.task_count == 4

def test_dags_import_errors(dag_bag):
    assert dag_bag.import_errors == {}
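
You can add more checks on top; for instance, a (hypothetical) sanity test asserting that at least one DAG was parsed:

def test_dags_are_loaded(dag_bag):
    # dag_bag.dags maps each dag_id to its parsed DAG object
    assert len(dag_bag.dags) > 0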

You can check the full example on Github: airflow-dag-unittests

Note: you can wrap up your test functions within a class using unittest.TestCase, as I did in the codebase on github.com/olivierbenard/airflow-dag-unittests. However, doing so prevents you from using fixtures. A workaround exists; I will let you check what I did.

Mocking Airflow Variables

If you are using Airflow Variables in your DAGs e.g.:

from airflow.models import Variable
MY_VARIABLE = Variable.get("my-variable")

You need to add the following lines:

@patch.dict(
    "os.environ",
    AIRFLOW_VAR_YOUR_VARIABLE="", # mock your variable, prefixed with AIRFLOW_VAR_
)
@pytest.fixture()
def dag_bag() -> DagBag:
    ...

Otherwise, you will stumble across the following error during your local tests:

raise KeyError(f"Variable {key} does not exist")
KeyError: 'Variable <your-variable> does not exist'
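
Alternatively, here is a sketch using pytest's built-in monkeypatch fixture (assuming an Airflow Variable named your_variable):

@pytest.fixture()
def dag_bag(monkeypatch) -> DagBag:
    # Airflow Variables fall back to environment variables prefixed with AIRFLOW_VAR_
    monkeypatch.setenv("AIRFLOW_VAR_YOUR_VARIABLE", "dummy-value")
    dag_folder = SCRIPT_DIRECTORY / ".." / "dags"
    return DagBag(dag_folder=dag_folder, read_dags_from_db=False)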

To conclude, Airflow DAGs are always a headache to test and integrate within your unittest workflow. I hope this makes it easier.

Upload Airflow DAG to DAGs folder (hosted on Google Cloud Composer)

Once you have your DAG defined in a nice dag.py file, you need to transfer it from your local environment to the DAGs/ folder of your Airflow instance.

airflow-dag-repo
├── dags
    └── dag.py
├── airflow_dag_repo
    ├── __init__.py
    └── utils.py
├── poetry.lock
└── pyproject.toml

This can be done manually via the Command Line Interface, replacing the GitLab $airflow_gcs_bucket_dev environment variable with the storage location your running Airflow instance relies on:

gsutil cp -r ./dags* gs://$airflow_gcs_bucket_dev/dags/airflow_dag_repo/

In our case, we assume our Airflow instance runs on the Google Cloud Composer service from Google Cloud Platform and uses Cloud Storage as its storage service. Hence the use of gsutil, which lets us access Cloud Storage (where our DAGs/ folder is located).

The second best option would be to integrate this command execution within your Gitlab CI/CD pipeline execution.

However, the most advanced and preferred option (like, the ultimate no-brainer) would be to use Terraform.

Go the extra mile

You want to reduce the workload the Airflow parser has to deal with. To improve the parsing time, ship zip-compressed files and remove unnecessary noise, e.g. by pruning away cache folders:

zip -r airflow_dag_repo.zip ./dags ./airflow_dag_repo
gsutil cp -r airflow_dag_repo.zip gs://$airflow_gcs_bucket_dev/dags/

With Terraform, the same upload can be described declaratively, pruning the __pycache__ folders along the way:

locals {
    dags = fileset("${var.root_dir}/../../dags/", "**")
    code = toset([
        for file in fileset("${var.root_dir}/../../airflow_dag_repo/", "**"):
            file if length(regexall(".*__pycache__.*", file)) == 0
    ])
    upload_folder = replace(var.service_name, "-", "_")
}

resource "google_storage_bucket_object" "dags" {
    for_each = local.dags
    name   = "dags/${local.upload_folder}/${each.key}"
    source = "${var.root_dir}/../../dags/${each.key}"
    bucket = data.google_secret_manager_secret_version.airflow_bucket.secret_data
}

resource "google_storage_bucket_object" "code" {
    for_each = local.code
    name   = "dags/${local.upload_folder}/${each.key}"
    source = "${var.root_dir}/../../airflow_dag_repo/${each.key}"
    bucket = data.google_secret_manager_secret_version.airflow_bucket.secret_data
}

Note: the service_name is inferred by Terragrunt:

locals {
    git_remote_origin_url = run_cmd("--terragrunt-quiet", "git", "config", "--get", "remote.origin.url")
    service_name = run_cmd("--terragrunt-quiet", "basename", "-s", ".git", local.git_remote_origin_url)
}

The resulting repository structure then looks like this:

airflow-dag-repo
├── dags
    └── dag.py
├── airflow_dag_repo
    ├── __init__.py
    └── utils.py
├── terraform
    ├── sync_dags.tf
    ├── tf_variables.tf
    └── versions.tf
├── terragrunt
    ├── dev
        ├── env.hcl
        ├── env.tfvars
        └── terragrunt.hcl
    ├── prod
        ├── env.hcl
        ├── env.tfvars
        └── terragrunt.hcl
    └── terragrunt.hcl
├── poetry.lock
└── pyproject.toml

Fix Airflow Pylance reportMissingImports

The Airflow Pylance reportMissingImports error can be fixed by installing the apache-airflow Python library within your virtual environment and then telling Visual Studio Code (or any other IDE) where to find it.

If you are using poetry as python package manager (more on poetry here) and Visual Studio Code (VSC):

  1. Run

    poetry add apache-airflow
    
  2. Find the path where the poetry virtual environment is located on your system:

    poetry env info --path
    
  3. On VSC, open the Command Palette, select >Python: Select Interpreter, and enter the path returned by the above command.

After this is done, you will not see this error anymore:

Import "airflow" could not be resolved Pylance reportMissingImports

Note: this method remains valid for all reportMissingImports errors you might encounter.

Airflow DAG labels

On Airflow you can label dependencies between the tasks using the Label() object.

from airflow.utils.edgemodifier import Label

task_A >> Label("Transform") >> task_B

This becomes handy when you want to explicitly document the edges of your DAG (Directed Acyclic Graph), e.g. to indicate what is happening between the tasks.

Minimal Functional Example

from airflow import DAG
from airflow.decorators import task
from airflow.utils.edgemodifier import Label

from datetime import datetime

with DAG(
    dag_id="test-airflow-dag-label",
    start_date=datetime(2023, 1, 23),
    schedule_interval="@daily",
    catchup=False,
    tags=["test"],
) as dag:

    @task
    def task_A():
        ...

    @task
    def task_B():
        ...

    task_A() >> Label("Move data from A to B") >> task_B()

Note: the ... (ellipsis literal) is equivalent to pass.

How to create custom Airflow Operators

Disclaimer: This article is intended for users with some hands-on experience with Airflow already. If this is not the case, I am working on an Airflow Essentials Survival Kit guide to be released. The link will be posted here as soon as it is the case.

To create custom Airflow Operators, all you need is to import the Airflow BaseOperator and extend it with the parameters and logic you need. You just have to fill in the template below:

from airflow.models.baseoperator import BaseOperator
from airflow.utils.decorators import apply_defaults

class MyCustomOperator(BaseOperator):

    # add the templated param(s) your custom operator takes:
    template_fields = ("your_custom_templated_field", )

    @apply_defaults
    def __init__(
        self,
        your_custom_field,
        your_custom_templated_field,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        # assign the normal and templated params:
        self.your_custom_field = your_custom_field
        self.your_custom_templated_field = your_custom_templated_field

    def execute(self, context):
        # the logic to perform
        ...

Note: The execute() method and context argument are mandatory.
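
For illustration, here is a minimal, hypothetical execute() reading runtime information from the context; ds holds the run's logical date:

    def execute(self, context):
        # 'ds' is the logical date as a YYYY-MM-DD string
        self.log.info("Running my custom logic for %s", context["ds"])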

Then, assuming you store this module in another file, you can call it inside your DAG file:

from path.to.file import MyCustomOperator

my_custom_job = MyCustomOperator(
    task_id="my_custom_operator",
    your_custom_templated_field="E.g. {{ds}}",
    your_custom_field=42,
)

Notes:

  • Because you have registered your_custom_templated_field in template_fields, it accepts Jinja-templated variables like {{ds}}. Registering templated fields is not mandatory though.
  • You can have more than one custom field.
  • Not all of an operator’s parameters accept Jinja-templated values. Look up in the documentation which fields are templated, e.g. for BigQueryInsertJobOperator.

Example: Ingesting Data From API

Context: you have data sitting behind an API. You want to fetch data from this API and ingest it into Google Cloud BigQuery. In between, you store the fetched raw data as temporary JSON files inside a Google Cloud Storage bucket. The next step is then to flush the files’ content into a BigQuery table.

In order to do so, you need to create a custom Airflow Operator that uses your API client to fetch the data (retrieving the credentials from an Airflow Connection) and stores the result in a temporary JSON file on Google Cloud Storage.

Your custom-made Airflow Operator, stored in a file like api_ingest/core.py, will look like this:

import json
from typing import Any
from airflow.models.baseoperator import BaseOperator
from airflow.utils.decorators import apply_defaults
from airflow.hooks.base_hook import BaseHook
from custom_api_client.client import (
    CustomAPIClient,
)

class APIToGCSOperator(
    BaseOperator
): # pylint: disable=too-few-public-methods
    """
    Custom-made Airflow Operator fetching data from the API.
    """

    template_fields = (
        "destination_file",
        "ingested_at"
    )

    @apply_defaults
    def __init__(
        self,
        destination_file: str,
        ingested_at: str,
        entity: str,
        *args: Any,
        **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)

        self.destination_file = destination_file
        self.ingested_at = ingested_at
        self.entity = entity

    def execute(
        self,
        context: Any
    ) -> None:  # pylint: disable=unused-argument
        """
        Fetches data from the API and writes into GCS bucket.

        1. Reads the credentials stored into the Airflow Connection.
        2. Instantiates the API client.
        3. Fetches the data with the given parameters.
        4. Flushes the result into a temporary GCS bucket.
        """

        api_connection = BaseHook.get_connection("your_connection_name")
        credentials = api_connection.extra_dejson
        client_id = credentials["client_id"]
        client_secret = credentials["client_secret"]
        refresh_token = credentials["non_expiring_refresh_token"]

        custom_api_client = CustomAPIClient(
            client_id = client_id,
            client_secret = client_secret,
            refresh_token = refresh_token,
        )

        with open(
            self.destination_file,
            "w",
            encoding="utf-8"
        ) as output:

            fetched_data_json = custom_api_client.fetch_entity(
                entity=self.entity
            )

            entity_json = dict(
                content = json.dumps(
                    fetched_data_json,
                    default=str,
                    ensure_ascii=False
                    ),
                ingested_at = self.ingested_at,
            )

            json.dump(
                entity_json,
                output,
                default=str,
                ensure_ascii=False
            )
            output.write("\n")

Note: you can create your own custom-made API clients. To make yours available inside your Airflow DAG, make sure to upload the package into the Airflow plugins folder beforehand.
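
For example, assuming the custom_api_client package sits at the root of your repository and the same bucket variable as before, the upload could look like this:

gsutil cp -r ./custom_api_client gs://$airflow_gcs_bucket_dev/plugins/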

In the main DAG’s module, inside the DAG’s with context manager, it will look like:

from airflow.operators.bash import BashOperator
from api_ingest.core import APIToGCSOperator

staging_dir = "{{var.value.local_storage_dir}}" + "/tempfiles/api_ingest"

create_staging_dir = BashOperator(
    task_id="create-staging-dir",
    bash_command=f"mkdir -p {staging_dir}",
)

cleanup_staging_dir = BashOperator(
    task_id="cleanup-staging-dir",
    bash_command=f"rm -rf {staging_dir}",
)

api_to_gcs = APIToGCSOperator(
    task_id="api-to-gcs",
    destination_file=staging_dir + "/data_{{ds_nodash}}.mjson",
    ingested_at="{{ts}}",
    entity="campaigns",
)
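
Still inside the with block, you can then chain the tasks, for instance:

create_staging_dir >> api_to_gcs >> cleanup_staging_dir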

Note: it is more efficient to use {{var.value.your_variable}} instead of Variable.get("your_variable"). The downsides: the real value is only substituted at execution time, and only for the fields accepting Jinja-templated values.

And here we go, you should now be able to create your own custom Airflow Operators! Have fun crafting Airflow and tuning it to your liking. 💫