Passing task outputs with Airflow XCom - python

I have an SSHOperator that writes a filepath to stdout. I'd like to get the os.path.basename of that filepath so that I can pass it as a parameter to my next task (which is an sftp pull). The idea is to download a remote file into the current working directory. This is what I have so far:
with DAG('my_dag',
    default_args = dict(...
        xcom_push = True,
    )
) as dag:
    # there is a get_update_id task here, which has been snipped for brevity
    get_results = SSHOperator(task_id = 'get_results',
        ssh_conn_id = 'my_remote_server',
        command = """cd ~/path/to/dir && python results.py -t p -u {{ task_instance.xcom_pull(task_ids='get_update_id') }}""",
        cmd_timeout = -1,
    )
    download_results = SFTPOperator(task_id = 'download_results',
        ssh_conn_id = 'my_remote_server',
        remote_filepath = base64.b64decode("""{{ task_instance.xcom_pull(task_ids='get_results') }}"""),
        local_filepath = os.path.basename(base64.b64decode("""{{ task_instance.xcom_pull(task_ids='get_results') }}""").decode()),
        operation = 'get',
    )
Airflow tells me there's an error on the remote_filepath = line. Investigating this further, I see that the value passed to base64.b64decode is not the xcom value from the get_results task, but is rather the raw string starting with {{.
My feeling is that since the operator arguments are templated, there's some under-the-hood magic that resolves the templated string, and os.path.basename doesn't benefit from it. So would I need to create an intermediate task to get the basename? Is there no shorthand for this, like the one I tried?
I'd appreciate any help on this

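You want to decode the XCom return value when Airflow renders the remote_filepath property for the task instance.
Plain Python calls such as base64.b64decode(...) in the operator arguments run when the DAG file is parsed, before any template is rendered, so they receive the literal "{{ ... }}" string. This means the b64decode function must be invoked within the template string.
There is a catch, though: the function has to be made available in the template context, either by passing it in params or by registering it at the DAG level as a user-defined macro or filter.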
import base64
import os

def basename_b64decode(value):
    # decode the base64 XCom value to text, drop stdout's trailing newline, keep only the file name
    return os.path.basename(base64.b64decode(value).decode().strip())

download_results = SFTPOperator(
    task_id = 'download_results',
    ssh_conn_id = 'my_remote_server',
    # b64decode returns bytes, so decode them to text inside the template
    remote_filepath = """{{ params.b64decode(ti.xcom_pull(task_ids='get_results')).decode().strip() }}""",
    local_filepath = """{{ params.basename_b64decode(ti.xcom_pull(task_ids='get_results')) }}""",
    operation = 'get',
    params = {
        'b64decode': base64.b64decode,
        'basename_b64decode': basename_b64decode,
    }
)
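A stdlib-only sketch of both halves of the problem. It assumes, as the question implies, that SSHOperator stored the command's stdout base64-encoded in XCom and that the remote script prints the path followed by a newline; the remote path below is made up. At parse time a plain Python call only ever sees the literal template text, while the helper does its job once it receives the rendered value:

```python
import base64
import os


def basename_b64decode(value):
    # Decode the base64 XCom value to text, drop the trailing newline
    # from stdout, then keep only the file name.
    return os.path.basename(base64.b64decode(value).decode().strip())


# At DAG-parse time the operator argument is still the raw template string,
# so any plain Python call sees the literal "{{ ... }}" text.
raw_template = "{{ task_instance.xcom_pull(task_ids='get_results') }}"
print(os.path.basename(raw_template))  # no "/" in it, so it comes back unchanged

# At run time, after rendering, the template resolves to the real XCom value.
# Here that value is faked by encoding a hypothetical remote path.
xcom_value = base64.b64encode(b"/home/user/path/to/dir/results_42.csv\n").decode()
print(basename_b64decode(xcom_value))  # results_42.csv
```

This is why the decoding has to move inside the template: only the rendered string carries the actual XCom value.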
For the DAG user-defined macro approach, you can write:
with DAG('my_dag',
    default_args = dict(...
        xcom_push = True,
    ),
    # user_defined_macros is a DAG argument, not a default_args entry
    user_defined_macros = dict(
        basename_b64decode = basename_b64decode,
        b64decode = base64.b64decode,
    ),
) as dag:
    download_results = SFTPOperator(
        task_id = 'download_results',
        ssh_conn_id = 'my_remote_server',
        # b64decode returns bytes, so decode them to text inside the template
        remote_filepath = """{{ b64decode(ti.xcom_pull(task_ids='get_results')).decode().strip() }}""",
        local_filepath = """{{ basename_b64decode(ti.xcom_pull(task_ids='get_results')) }}""",
        operation = 'get',
    )

Related

Calling TaskGroup with Dynamic sub task id from BranchPythonOperator

I want to call a TaskGroup with a Dynamic sub-task id from BranchPythonOperator.
This is the DAG flow that I have:
[Image: graph view of branch_dag]
My case is that I want to check whether a table exists in BigQuery.
If it exists: do nothing and end the DAG.
If it does not exist: ingest the data from Postgres to Google Cloud Storage.
I know that to branch into a TaskGroup from a BranchPythonOperator, the callable has to return the task ID in the following format:
group_task_id.task_id
The problem is that my task group's sub-task IDs are dynamic; they depend on how many times I loop inside the TaskGroup. So the sub-tasks will be:
parent_task_id.sub_task_1
parent_task_id.sub_task_2
parent_task_id.sub_task_3
...
parent_task_id.sub_task_x
This is the code for my DAG:
import airflow
from airflow.providers.google.cloud.transfers.postgres_to_gcs import PostgresToGCSOperator
from airflow.utils.task_group import TaskGroup
from google.cloud.exceptions import NotFound
from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from airflow.operators.dummy import DummyOperator
from google.cloud import bigquery

default_args = {
    'owner': 'Airflow',
    'start_date': airflow.utils.dates.days_ago(2),
}

with DAG(dag_id='branch_dag', default_args=default_args, schedule_interval=None) as dag:

    def create_task_group(worker=1):
        var = dict()
        with TaskGroup(group_id='parent_task_id') as tg1:
            for i in range(worker):
                var[f'sub_task_{i}'] = PostgresToGCSOperator(
                    task_id = f'sub_task_{i}',
                    postgres_conn_id = 'some_postgres_conn_id',
                    sql = 'test.sql',
                    bucket = 'test_bucket',
                    filename = 'test_file.json',
                    export_format = 'json',
                    gzip = True,
                    params = {
                        'worker': worker
                    }
                )
        return tg1

    def is_exists_table():
        client = bigquery.Client()
        try:
            table_name = client.get_table('dataset_id.some_table')
            if table_name:
                return 'task_end'
        except NotFound as error:
            return 'parent_task_id'

    task_start = DummyOperator(
        task_id = 'start'
    )

    task_branch_table = BranchPythonOperator(
        task_id ='check_table_exists_in_bigquery',
        python_callable = is_exists_table
    )

    task_pg_to_gcs_init = create_task_group(worker=3)

    task_end = DummyOperator(
        task_id = 'end',
        trigger_rule = 'all_done'
    )

    task_start >> task_branch_table >> task_end
    task_start >> task_branch_table >> task_pg_to_gcs_init >> task_end
When I run the DAG, it returns:
airflow.exceptions.TaskNotFound: Task parent_task_id not found
This is expected, but what I don't know is how to iterate over parent_task_id.sub_task_x in the is_exists_table function. Or is there a workaround?
This is the test.sql file if it's needed
SELECT
id,
name,
country
FROM some_table
WHERE 1=1
AND ABS(MOD(hashtext(id::TEXT), 3)) = {{params.worker}};
-- returns 1M+ rows
I have already seen this question for reference, but I think my case is more specific.
I found a dirty way around it.
What I did was create one additional task, a DummyOperator called task_pass:
task_pass = DummyOperator(
    task_id = 'pass_to_task_group'
)
So the DAG flow now looks like this:
task_start >> task_branch_table >> task_end
task_start >> task_branch_table >> task_pass >> task_pg_to_gcs_init >> task_end
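Another option, which avoids the extra DummyOperator, is to have the branch callable return a list: BranchPythonOperator follows every task ID in a returned list, and tasks inside a TaskGroup are addressed as "group_id.task_id". The sketch below separates the decision from the BigQuery lookup so it can be tested on its own; `choose_branch` and its arguments are hypothetical names, and `table_exists` stands in for the `client.get_table` check. Note that the end task in the question was created with `task_id='end'`, so that is the ID the callable should return, not `'task_end'`.

```python
from typing import List, Union


def choose_branch(table_exists: bool, worker: int = 3,
                  group_id: str = "parent_task_id") -> Union[str, List[str]]:
    """Return the task ID(s) a BranchPythonOperator callable should follow."""
    if table_exists:
        # The end task was created with task_id='end'.
        return "end"
    # Build one ID per sub-task created by the loop in create_task_group().
    return [f"{group_id}.sub_task_{i}" for i in range(worker)]


print(choose_branch(True))   # end
print(choose_branch(False))  # ['parent_task_id.sub_task_0', 'parent_task_id.sub_task_1', 'parent_task_id.sub_task_2']
```

Inside the DAG, is_exists_table would keep its get_table/NotFound check and return choose_branch(...) with the same worker count passed to create_task_group(worker=3). Because task_end uses trigger_rule='all_done', it still runs after whichever branch was taken.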
Also, there is one mistake in my original DAG code: the params value I set was worker. This is wrong because worker is a constant, while the value that has to change per sub-task is the loop variable i. So I changed it from:
params = {'worker': worker}
to:
params = {'worker': i}
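Why i is the right value: test.sql keeps the rows where ABS(MOD(hashtext(id::TEXT), 3)) equals the parameter, so with three workers the valid shard numbers are 0, 1 and 2, exactly what range(worker) yields. Passing worker (always 3) makes every task ask for shard 3, which never matches. The sketch below imitates that split in plain Python; it uses Python's own hash() of integer IDs as a stand-in for Postgres hashtext(), so the actual row-to-shard assignment differs from what Postgres would produce.

```python
def shard_rows(row_ids, num_workers):
    """Assign each row to exactly one worker, mimicking ABS(MOD(hash(id), N))."""
    shards = {i: [] for i in range(num_workers)}
    for row_id in row_ids:
        # Python's hash() stands in for Postgres hashtext() here.
        shards[abs(hash(row_id)) % num_workers].append(row_id)
    return shards


rows = list(range(10))
shards = shard_rows(rows, 3)
# params={'worker': i} for i in range(3): each sub-task reads one distinct shard.
print({i: len(shards[i]) for i in range(3)})
# params={'worker': worker} (always 3): no row ever lands in shard 3.
print(3 in shards)  # False
```

Together the shards 0..N-1 cover every row exactly once, which is what lets the N sub-tasks export the table in parallel without overlap.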

Using xcom pushed value as parameter in BigQueryToGCSOperator

I need to count the number of rows in a table and use the row count in the filename of an export to GCS. The following is an excerpt from my DAG.
with models.DAG(
    'my_dag',
    schedule_interval = '0 6 * * 1',
    start_date = datetime(2022, 1, 1),
    catchup = False
) as dag:
    # create segment filtered views to output CSV to GCS
    def prepareSegmentTables(segment, **kwargs):
        segment_table_queries = f"""
            TRUNCATE TABLE dataset.some_table;
            INSERT INTO dataset.some_table (column1)
            SELECT DISTINCT column1
            FROM dataset.some_other_table
            WHERE column2 = '{ segment['id'] }';
        """
        # execute query
        client.query(segment_table_queries).result()
        # store the row counts of each type
        kwargs['ti'].xcom_push(
            key = "ROW_COUNTS",
            value = {
                "column1": getTableRowCount("dataset.some_table"),
            }
        )

    def get_row_counts(segment, **kwargs):
        ROW_COUNTS = kwargs['ti'].xcom_pull(
            key = "ROW_COUNTS",
            task_ids = [ f"prepare_segment_tables" ]
        )

    #tasks
    prepare_segment_tables = PythonOperator(
        task_id = f"prepare_segment_tables",
        python_callable = prepareSegmentTables,
        op_kwargs = { "segment": segment },
        dag = dag
    )
    export_to_gcs = BigQueryToCloudStorageOperator(
        task_id = f"gcs_lr_to_li_auid_{segment['id']}",
        source_project_dataset_table = f"{GCP_PROJECT}.{DATASET_NAME}.some_table",
        destination_cloud_storage_uris = f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_"
            + str( ti.xcom_pull(key = "ROW_COUNTS", task_ids = [ f"prepare_segment_tables" ])[0].column1 )
            + f"_{TODAY_STR}.csv",
        # this works though
        # destination_cloud_storage_uris = f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_" + str( getTableRowCount("dataset.some_table") ) + f"_{TODAY_STR}.csv",
        compression = 'NONE', export_format = 'CSV', field_delimiter = ',', print_header = True
    )
    prepare_segment_tables >> export_to_gcs
As can be seen, I am pushing ROW_COUNTS into xcom while calling prepareSegmentTables via a PythonOperator. When I do xcom_pull inside another PythonOperator, calling get_row_counts, it properly pulls the value, but when I pass the same syntax as a parameter to BigQueryToCloudStorageOperator or BigQueryToGCSOperator, it throws an error.
It says that ti (or kwargs['ti'], depending on which one I use) is undefined. Some people suggest using double curly braces {{ }}, but even that didn't work for me.
For now, I have resorted to calling getTableRowCount() directly in the parameter instead of first storing the count in a variable. It works, but I use the filename downstream at least one more time, and this approach results in unnecessarily querying the table for a row count multiple times.
Any help getting xcom to work or to figure out a way to get row count in the filename efficiently is appreciated.
Airflow is a distributed system, and it is important to note that your DAG code is not all executed in the same context.
The DAG file is parsed and the DAG assembled periodically, while each task is executed on a worker. XCom is only available inside the worker context, while the task runs.
As you saw, ti or kwargs['ti'] lets you access XComs from inside a PythonOperator (specifically, in its python_callable), but in your BigQueryToCloudStorageOperator call you don't have a task instance available, because that code runs when the DAG file is parsed, not when the task executes.
You can use Jinja templating to defer fetching the XCom until the task runs on the worker (read more here: https://airflow.apache.org/docs/apache-airflow/stable/templates-ref.html )
You probably need something like:
BigQueryToCloudStorageOperator(
    ...
    destination_cloud_storage_uris = f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_" + "{{ ti.xcom_pull(key='ROW_COUNTS', task_ids=['prepare_segment_tables'])[0]['column1'] }}" + f"_{TODAY_STR}.csv",
    ...
)
Note that I am being careful not to mix f-strings and Jinja there, because both use the {} syntax and don't play well together.
You will also probably need to make sure that your XCom parses as you expect in Jinja, and you may want to return your XCom differently or parse it in advance with a PythonOperator.
See here for help with what you can do within a Jinja template: https://jinja.palletsprojects.com/
Note also that this works only because "destination_cloud_storage_uris" is a templated field on this operator - and not all fields are.
https://airflow.apache.org/docs/apache-airflow/1.10.15/_modules/airflow/contrib/operators/bigquery_to_gcs.html#BigQueryToCloudStorageOperator
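A small sketch of the two safe ways to combine Python-side values with a Jinja expression: plain concatenation, or a single f-string in which the Jinja braces are doubled so Python emits them literally. The bucket, path, segment and date values are placeholders.

```python
GCS_BUCKET = "my-bucket"          # placeholder
FILENAME_PATH = "exports/"        # placeholder
segment = {"name": "segment_a"}   # placeholder
TODAY_STR = "20220103"            # placeholder

# Jinja expression that Airflow renders on the worker, not at parse time.
ROW_COUNT_JINJA = (
    "{{ ti.xcom_pull(key='ROW_COUNTS', "
    "task_ids=['prepare_segment_tables'])[0]['column1'] }}"
)

# Option 1: f-strings for the Python values, a plain string for the Jinja part.
uri_concat = (
    f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_"
    + ROW_COUNT_JINJA
    + f"_{TODAY_STR}.csv"
)

# Option 2: one f-string where "{{{{" / "}}}}" produce literal "{{" / "}}".
uri_fstring = (
    f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_"
    f"{{{{ ti.xcom_pull(key='ROW_COUNTS', "
    f"task_ids=['prepare_segment_tables'])[0]['column1'] }}}}"
    f"_{TODAY_STR}.csv"
)

print(uri_concat)
print(uri_concat == uri_fstring)  # True
```

Either string is what the operator receives at parse time; Airflow then renders the {{ ... }} part on the worker because destination_cloud_storage_uris is a templated field.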

Read run_id in airflow operator for non templated fields

I am trying to read run_id for DAG in SnowflakeOperator to set a session parameter, query_tag. But it seems like the session parameter is not templated.
snowflake_operator = SnowflakeOperator(
task_id='snowflake_task',
snowflake_conn_id='snowflake_conn',
sql='resources/some.sql',
warehouse='my_warehouse',
database='my_database',
role='my_role',
session_parameters={
'QUERY_TAG': 'dagrun_{{run_id}}'
}
)
How can I reference run_id and use it as an input here?
You need to make the non-templated field templated.
class MySnowflakeOperator(SnowflakeOperator):
    template_fields = (
        "session_parameters",
    ) + SnowflakeOperator.template_fields
Then you can use it as:
snowflake_operator = MySnowflakeOperator(
    task_id='snowflake_task',
    snowflake_conn_id='snowflake_conn',
    sql='resources/some.sql',
    warehouse='my_warehouse',
    database='my_database',
    role='my_role',
    session_parameters={
        'QUERY_TAG': 'dagrun_{{run_id}}'
    }
)
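To illustrate the mechanism, here is a toy model in plain Python, not Airflow code: an operator class lists the attributes to render in template_fields, and the renderer walks into nested dicts, lists and tuples (Airflow's own template rendering also recurses into such containers, which is why a templated session_parameters dict gets its values rendered). `ToyOperator`, `MyToyOperator`, `render` and the str.replace substitution are simplified, hypothetical stand-ins for Airflow's Jinja rendering.

```python
from typing import Any, Dict


def render(value: Any, context: Dict[str, str]) -> Any:
    """Toy stand-in for Jinja: substitutes '{{key}}' and '{{ key }}' markers."""
    if isinstance(value, str):
        for key, replacement in context.items():
            value = value.replace("{{%s}}" % key, replacement)
            value = value.replace("{{ %s }}" % key, replacement)
        return value
    if isinstance(value, dict):
        # Recurse into dict values, as Airflow does for nested containers.
        return {k: render(v, context) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(render(v, context) for v in value)
    return value


class ToyOperator:
    # Only attributes listed here are rendered before execution.
    template_fields = ("sql",)

    def __init__(self, sql: str, session_parameters: dict):
        self.sql = sql
        self.session_parameters = session_parameters

    def render_template_fields(self, context: Dict[str, str]) -> None:
        for field in self.template_fields:
            setattr(self, field, render(getattr(self, field), context))


class MyToyOperator(ToyOperator):
    # Same pattern as MySnowflakeOperator: add the new field, keep the parent's.
    template_fields = ("session_parameters",) + ToyOperator.template_fields


context = {"run_id": "manual__2022-01-01T00:00:00"}
params = {"QUERY_TAG": "dagrun_{{run_id}}"}

plain = ToyOperator("select 1", dict(params))
plain.render_template_fields(context)
print(plain.session_parameters)      # {'QUERY_TAG': 'dagrun_{{run_id}}'}

templated = MyToyOperator("select 1", dict(params))
templated.render_template_fields(context)
print(templated.session_parameters)  # {'QUERY_TAG': 'dagrun_manual__2022-01-01T00:00:00'}
```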

How to pass a variable task_ids to a xcom.pull in Airflow?

Sometimes I find it handy to create tasks using a loop.
Below is an example of a SqoopOperator that uses the XCom value from the preceding PythonOperator in its where clause. I am trying to use the variable get_delivery_sqn_task_id to access the correct XCom value with ti.xcom_pull(task_ids=get_delivery_sqn_task_id); however, this does not work (it returns ()).
I can take everything out of the loop, but this makes the code quite ugly I think. Is there an elegant solution to have a variable task_ids to retrieve xcom values? I guess otherwise the best solution is using the Airflow Variables.
for table in tables:
    get_delivery_sqn_task_id = 'get_delivery_sqn_' + table
    get_delivery_sqn_task = PythonOperator(
        task_id = get_delivery_sqn_task_id,
        python_callable = get_delivery_sqn,
        op_kwargs = {
            'table_name': table
        },
        provide_context = True,
        dag = dag
    )
    sqoop_operator_task = SqoopOperator(
        task_id = "sqoop_" + table,
        conn_id = "DWDH_PROD",
        table = table,
        cmd_type = "import",
        target_dir = "/sourcedata/sqoop_tmp/" + table,
        num_mappers = 1,
        where = "delivery_sqn > {{ ti.xcom_pull(task_ids=get_delivery_sqn_task_id, key='return_value') }}",
        dag = dag
    )
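You can do (note the quotes around {}, so that the rendered Jinja receives the task ID as a string rather than as an undefined variable name):
"delivery_sqn > {{{{ ti.xcom_pull(task_ids='{}', key='return_value') }}}}".format(get_delivery_sqn_task_id)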

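A quick stdlib check of three equivalent ways to build the templated where string for one loop iteration. str.format() and f-strings both turn "{{" into "{", hence the quadruple braces, and the task ID has to end up inside quotes so Jinja treats it as a string. The table name "orders" is a placeholder.

```python
table = "orders"  # placeholder table name
get_delivery_sqn_task_id = "get_delivery_sqn_" + table

# 1. str.format: literal braces are written twice, so Jinja's {{ }} becomes {{{{ }}}}.
where_format = (
    "delivery_sqn > {{{{ ti.xcom_pull(task_ids='{}', key='return_value') }}}}"
    .format(get_delivery_sqn_task_id)
)

# 2. f-string: same brace-doubling rule.
where_fstring = (
    f"delivery_sqn > {{{{ ti.xcom_pull(task_ids='{get_delivery_sqn_task_id}', "
    f"key='return_value') }}}}"
)

# 3. Concatenation: no escaping needed at all.
where_concat = (
    "delivery_sqn > {{ ti.xcom_pull(task_ids='" + get_delivery_sqn_task_id
    + "', key='return_value') }}"
)

print(where_format)
# delivery_sqn > {{ ti.xcom_pull(task_ids='get_delivery_sqn_orders', key='return_value') }}
```

Any of the three can be passed as where=... inside the loop; Airflow then renders the {{ ... }} part when each Sqoop task runs.
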
How to pass parameters to Airflow on_success_callback and on_failure_callback

I have implemented email alerts on success and failure using on_success_callback and on_failure_callback.
According to the Airflow documentation, a context dictionary is passed as a single parameter to this function.
How can I pass another parameter to these callback methods?
Here is my code
from datetime import datetime

from airflow.utils.email import send_email_smtp

def task_success_alert(context):
    subject = "[Airflow] DAG {0} - Task {1}: Success".format(
        context['task_instance_key_str'].split('__')[0],
        context['task_instance_key_str'].split('__')[1]
    )
    html_content = """
    DAG: {0}<br>
    Task: {1}<br>
    Succeeded on: {2}
    """.format(
        context['task_instance_key_str'].split('__')[0],
        context['task_instance_key_str'].split('__')[1],
        datetime.now()
    )
    send_email_smtp(dag_vars["dev_mailing_list"], subject, html_content)

def task_failure_alert(context):
    subject = "[Airflow] DAG {0} - Task {1}: Failed".format(
        context['task_instance_key_str'].split('__')[0],
        context['task_instance_key_str'].split('__')[1]
    )
    html_content = """
    DAG: {0}<br>
    Task: {1}<br>
    Failed on: {2}
    """.format(
        context['task_instance_key_str'].split('__')[0],
        context['task_instance_key_str'].split('__')[1],
        datetime.now()
    )
    send_email_smtp(dag_vars["dev_mailing_list"], subject, html_content)

default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': datetime(2019, 6, 13),
    'on_success_callback': task_success_alert,
    'on_failure_callback': task_failure_alert
}
I intend to move the callbacks to another package and pass the email address as parameter.
You could define a function in your DAG file that calls the function from your package, passing the email as an argument. You can refine it further at the DAG level so that only the information the emails need is passed.
from package import outer_task_success_callback

email = 'xyz@example.com'

def task_success_alert(context):
    dag_id = context['dag'].dag_id
    task_id = context['task_instance'].task_id
    outer_task_success_callback(dag_id, task_id, email)

default_args = {
    'owner': 'airflow',
    'depends_on_past': False,
    'start_date': datetime(2019, 6, 13),
    'on_success_callback': task_success_alert,
    'on_failure_callback': task_failure_alert
}
This will allow you to customize before you call the function in your package.
On a side note, Airflow has built-in SMTP email alerts (the email, email_on_failure and email_on_retry operator arguments). Instead of writing your own solution, you can use those.
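Since the goal is to move the callbacks into a separate package, a variant of the same idea is a factory function in that package that takes the email address and returns a callback with the one-argument signature Airflow expects. `make_success_alert` and `send` are hypothetical names; `send` stands in for whatever actually sends the email, and the context dict below is a hand-made stand-in for the real one.

```python
from typing import Callable, List, Tuple


def make_success_alert(email: str,
                       send: Callable[[str, str, str], None]) -> Callable[[dict], None]:
    """Return an on_success_callback that closes over `email`."""
    def task_success_alert(context: dict) -> None:
        # Airflow passes only the context; the address comes from the closure.
        dag_id, task_id = context["task_instance_key_str"].split("__")[:2]
        subject = f"[Airflow] DAG {dag_id} - Task {task_id}: Success"
        send(email, subject, f"DAG: {dag_id}<br>Task: {task_id}")
    return task_success_alert


sent: List[Tuple[str, str]] = []
callback = make_success_alert("xyz@example.com",
                              lambda to, subject, body: sent.append((to, subject)))
callback({"task_instance_key_str": "my_dag__my_task__20190613"})
print(sent)  # [('xyz@example.com', '[Airflow] DAG my_dag - Task my_task: Success')]
```

In the DAG this could be wired as 'on_success_callback': make_success_alert(dag_vars["dev_mailing_list"], send_email_smtp), since send_email_smtp takes the recipient, subject and HTML body as its first three arguments.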
You can use functools.partial to create a function with a predefined argument. The callback itself must accept that argument, e.g. def task_success_alert(context, email):
from functools import partial
new_task_success_alert = partial(task_success_alert, email='your_email')
And then add the new function as a callback:
on_success_callback=new_task_success_alert
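A self-contained check of the partial approach, with a toy callback that accepts email and returns its message instead of sending it (Airflow ignores a callback's return value); the context dict and address are placeholders:

```python
from functools import partial


def task_success_alert(context, email):
    # Toy callback: builds the message instead of sending an email.
    dag_id, task_id = context["task_instance_key_str"].split("__")[:2]
    return f"notify {email}: {dag_id}.{task_id} succeeded"


new_task_success_alert = partial(task_success_alert, email="your_email@example.com")

# Airflow invokes the callback with the context as its only positional argument.
print(new_task_success_alert({"task_instance_key_str": "my_dag__my_task__20190613"}))
# notify your_email@example.com: my_dag.my_task succeeded
```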
You can create a task whose only purpose is to push configuration settings through XComs. You can then pull the configuration in the callback via the context, as the task_instance object is included in it.
def push_configuration(ti, params):
    # push under the same key the callback pulls
    ti.xcom_push(key='params', value=params)

def task_success_alert(context):
    ti = context.get('ti')
    params = ti.xcom_pull(key='params', task_ids='Settings')
    ...

step0 = PythonOperator(
    task_id='Settings',
    python_callable=push_configuration,
    op_kwargs={'params': params})

step1 = BashOperator(
    task_id='step1',
    bash_command='pwd',
    on_success_callback=task_success_alert)

# Settings must run first so its XCom exists when step1's callback fires
step0 >> step1
