The code below works, but my requirement is to pass totalbuckets as an input to the function rather than as a global variable. I am having trouble passing it as a variable and doing an xcom_pull in the next task. This DAG creates buckets based on the number of inputs, and totalbuckets is a constant. I appreciate your help in advance.
```python
from collections import defaultdict

from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.utils.trigger_rule import TriggerRule

with DAG('test-live', catchup=False, schedule_interval=None, default_args=args) as test_live:
    totalbuckets = 3

    # branches based on number of buckets
    def branch_buckets(**context):
        buckets = defaultdict(list)
        for i in range(len(inputs_to_process)):
            buckets[f'bucket_{1 + i % totalbuckets}'].append(inputs_to_process[i])
        for bucket_name, input_sublist in buckets.items():
            context['ti'].xcom_push(key=bucket_name, value=input_sublist)
        return list(buckets.keys())

    # BranchPythonOperator will launch the buckets and distribute the inputs among them
    branch_buckets = BranchPythonOperator(
        task_id='branch_buckets',
        python_callable=branch_buckets,
        trigger_rule=TriggerRule.NONE_FAILED,
        provide_context=True,
        dag=test_live
    )

    # update provider tables with merge sql
    def update_inputs(sf_conn_id, bucket_name, **context):
        input_sublist = context['ti'].xcom_pull(task_ids='branch_buckets', key=bucket_name)
        print(f"Processing inputs {input_sublist} in {bucket_name}")
        from custom.hooks.snowflake_hook import SnowflakeHook
        for p in input_sublist:
            merge_sql = f"""
                merge into ......"""

    bucket_tasks = []
    for i in range(totalbuckets):
        task = PythonOperator(
            task_id=f'bucket_{i + 1}',
            python_callable=update_inputs,
            provide_context=True,
            op_kwargs={'bucket_name': f'bucket_{i + 1}', 'sf_conn_id': SF_CONN_ID},
            dag=test_live
        )
        bucket_tasks.append(task)
```
CodePudding user response:
@hussein awala I am doing something like the below, but I can no longer reference totalbuckets when building bucket_tasks, since it is not a top-level variable anymore:
```python
from collections import defaultdict

from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.utils.trigger_rule import TriggerRule

with DAG('test-live', catchup=False, schedule_interval=None, default_args=args) as test_live:
    # totalbuckets = 3

    def branch_buckets(totalbuckets, **context):
        buckets = defaultdict(list)
        for i in range(len(inputs_to_process)):
            buckets[f'bucket_{1 + i % totalbuckets}'].append(inputs_to_process[i])
        for bucket_name, input_sublist in buckets.items():
            context['ti'].xcom_push(key=bucket_name, value=input_sublist)
        return list(buckets.keys())

    # BranchPythonOperator will launch the buckets and distribute the inputs among them
    branch_buckets = BranchPythonOperator(
        task_id='branch_buckets',
        python_callable=branch_buckets,
        trigger_rule=TriggerRule.NONE_FAILED,
        provide_context=True,
        op_kwargs={'totalbuckets': 3},
        dag=test_live
    )

    # update provider tables with merge sql
    def update_inputs(sf_conn_id, bucket_name, **context):
        input_sublist = context['ti'].xcom_pull(task_ids='branch_buckets', key=bucket_name)
        print(f"Processing inputs {input_sublist} in {bucket_name}")
        from custom.hooks.snowflake_hook import SnowflakeHook
        for p in input_sublist:
            merge_sql = f"""
                merge into ......"""

    bucket_tasks = []
    for i in range(totalbuckets):  # fails here: totalbuckets is only passed via op_kwargs now
        task = PythonOperator(
            task_id=f'bucket_{i + 1}',
            python_callable=update_inputs,
            provide_context=True,
            op_kwargs={'bucket_name': f'bucket_{i + 1}', 'sf_conn_id': SF_CONN_ID},
            dag=test_live
        )
        bucket_tasks.append(task)
```
CodePudding user response:
If totalbuckets differs from one run to the next, it should be a run conf variable; you can provide it for each run created from the UI, CLI, Airflow REST API, or even the Python API.
```python
from collections import defaultdict

from airflow import DAG
from airflow.operators.python import PythonOperator, BranchPythonOperator
from airflow.models.param import Param
from airflow.utils.trigger_rule import TriggerRule

with DAG(
    'test-live',
    catchup=False,
    schedule_interval=None,
    default_args=args,
    params={"totalbuckets": Param(default=3, type="integer")},
) as test_live:
    # branches based on number of buckets
    def branch_buckets(**context):
        # the rendered params (defaults merged with the dag run conf) are available
        # in the task context, so read the per-run value there
        totalbuckets = int(context["params"]["totalbuckets"])
        buckets = defaultdict(list)
        for i in range(len(inputs_to_process)):
            buckets[f'bucket_{1 + i % totalbuckets}'].append(inputs_to_process[i])
        for bucket_name, input_sublist in buckets.items():
            context['ti'].xcom_push(key=bucket_name, value=input_sublist)
        return list(buckets.keys())

    # BranchPythonOperator will launch the buckets and distribute the inputs among them
    branch_buckets = BranchPythonOperator(
        task_id='branch_buckets',
        python_callable=branch_buckets,
        trigger_rule=TriggerRule.NONE_FAILED,
        provide_context=True,
        dag=test_live
    )

    # update provider tables with merge sql
    def update_inputs(sf_conn_id, bucket_name, **context):
        input_sublist = context['ti'].xcom_pull(task_ids='branch_buckets', key=bucket_name)
        print(f"Processing inputs {input_sublist} in {bucket_name}")
        from custom.hooks.snowflake_hook import SnowflakeHook
        for p in input_sublist:
            merge_sql = f"""
                merge into ......"""
    # Jinja such as "{{ params.totalbuckets }}" is only rendered at task runtime, so it
    # cannot control how many task objects exist: that is fixed when the file is parsed.
    # Create an assumed upper bound of bucket tasks (MAX_BUCKETS is illustrative);
    # branch_buckets only selects the ones the current run actually needs.
    MAX_BUCKETS = 10
    bucket_tasks = []
    for i in range(MAX_BUCKETS):
        task = PythonOperator(
            task_id=f'bucket_{i + 1}',
            python_callable=update_inputs,
            provide_context=True,
            op_kwargs={'bucket_name': f'bucket_{i + 1}', 'sf_conn_id': SF_CONN_ID},
            dag=test_live
        )
        bucket_tasks.append(task)
```
Example to run it:
airflow dags trigger --conf '{"totalbuckets": 10}' test-live
Or via the UI.
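The REST API route works the same way. Here is a minimal sketch using the Airflow 2.x stable REST API; the webserver URL and basic-auth credentials below are placeholder assumptions:

```python
# Minimal sketch: trigger a run with a custom totalbuckets via the stable REST API.
import requests

resp = requests.post(
    "http://localhost:8080/api/v1/dags/test-live/dagRuns",  # assumed webserver URL
    auth=("airflow", "airflow"),                             # assumed credentials
    json={"conf": {"totalbuckets": 10}},  # becomes dag_run.conf and overrides the param default
)
resp.raise_for_status()
print(resp.json())
```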
Update: if it's static but differs from one environment to another, it can be an Airflow Variable, and you can read it directly in the tasks using Jinja so it isn't fetched on every DAG-file parse.
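As a minimal sketch of that pattern inside the DAG above (assuming a Variable named totalbuckets has been created via the UI or CLI; the use_buckets task is purely illustrative):

```python
# Minimal sketch: read an Airflow Variable lazily via Jinja in a templated field.
# The Variable name "totalbuckets" and this task are assumptions for illustration.
def use_buckets(totalbuckets, **context):
    # arrives already rendered as a string, e.g. "3"
    print(f"totalbuckets = {int(totalbuckets)}")

use_buckets_task = PythonOperator(
    task_id="use_buckets",
    python_callable=use_buckets,
    op_kwargs={
        # op_kwargs is templated: resolved at task runtime, not on every DAG-file parse
        "totalbuckets": "{{ var.value.totalbuckets }}",
    },
)
```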
But if it's completely static, the most recommended solution is a plain Python variable, as you are doing, because reading the dag run conf or an Airflow Variable makes the task/DAG send a query to the database.