diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py index f88139ff0..ff4926b40 100644 --- a/cosmos/core/airflow.py +++ b/cosmos/core/airflow.py @@ -1,6 +1,7 @@ import importlib import logging from typing import Optional +import inspect from airflow.models import BaseOperator from airflow.models.dag import DAG @@ -24,14 +25,25 @@ def get_airflow_task(task: Task, dag: DAG, task_group: Optional[TaskGroup] = Non # fully qualified name defined in the task module_name, class_name = task.operator_class.rsplit(".", 1) module = importlib.import_module(module_name) - Operator = getattr(module, class_name) - - airflow_task = Operator( - task_id=task.id, - dag=dag, - task_group=task_group, + operator = getattr(module, class_name) + + # ensure we only pass the arguments that the operator expects + supported_args = set() + for inherited_class in operator.mro(): + for arg in inspect.signature(inherited_class.__init__).parameters: + supported_args.add(arg) + + potential_operator_args = { + "task_id": task.id, + "dag": dag, + "task_group": task_group, **task.arguments, - ) + } + operator_args = { + arg_key: arg_value for arg_key, arg_value in potential_operator_args.items() if arg_key in supported_args + } + + airflow_task = operator(**operator_args) if not isinstance(airflow_task, BaseOperator): raise TypeError(f"Operator class {task.operator_class} is not a subclass of BaseOperator") diff --git a/dev/dags/basic_cosmos_task_group.py b/dev/dags/basic_cosmos_task_group.py index 48e31d4da..3835837cf 100644 --- a/dev/dags/basic_cosmos_task_group.py +++ b/dev/dags/basic_cosmos_task_group.py @@ -38,6 +38,9 @@ def basic_cosmos_task_group() -> None: render_config=RenderConfig( test_behavior="after_all", ), + operator_args={ + "on_warning_callback": print, + }, ) post_dbt = EmptyOperator(task_id="post_dbt")