I need to count the number of rows in a table and use the row count in the filename of an export to GCS. The following is an excerpt from my DAG.
with models.DAG(
    'my_dag',
    schedule_interval = '0 6 * * 1',
    start_date = datetime(2022, 1, 1),
    catchup = False
) as dag:

    # create segment filtered views to output CSV to GCS
    def prepareSegmentTables(segment, **kwargs):
        segment_table_queries = f"""
            TRUNCATE TABLE dataset.some_table;
            INSERT INTO dataset.some_table (column1)
            SELECT DISTINCT column1
            FROM dataset.some_other_table
            WHERE column2 = '{ segment['id'] }';
        """
        # execute query
        client.query(segment_table_queries).result()
        # store the row counts of each type
        kwargs['ti'].xcom_push(
            key = "ROW_COUNTS",
            value = {
                "column1": getTableRowCount("dataset.some_table"),
            }
        )

    def get_row_counts(segment, **kwargs):
        ROW_COUNTS = kwargs['ti'].xcom_pull(
            key = "ROW_COUNTS",
            task_ids = [ f"prepare_segment_tables" ]
        )

    # tasks
    prepare_segment_tables = PythonOperator(
        task_id = f"prepare_segment_tables",
        python_callable = prepareSegmentTables,
        op_kwargs = { "segment": segment },
        dag = dag
    )

    export_to_gcs = BigQueryToCloudStorageOperator(
        task_id = f"gcs_lr_to_li_auid_{segment['id']}",
        source_project_dataset_table = f"{GCP_PROJECT}.{DATASET_NAME}.some_table",
        destination_cloud_storage_uris = f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_"
            + str( ti.xcom_pull(key = "ROW_COUNTS", task_ids = [ f"prepare_segment_tables" ])[0].column1 )
            + f"_{TODAY_STR}.csv",
        # this works though
        # destination_cloud_storage_uris = f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_" + str( getTableRowCount("dataset.some_table") ) + f"_{TODAY_STR}.csv",
        compression = 'NONE', export_format = 'CSV', field_delimiter = ',', print_header = True
    )

    prepare_segment_tables >> export_to_gcs
As can be seen, I am pushing ROW_COUNTS into xcom while calling prepareSegmentTables via a PythonOperator. When I do xcom_pull inside another PythonOperator, calling get_row_counts, it properly pulls the value, but when I pass the same syntax as a parameter to BigQueryToCloudStorageOperator or BigQueryToGCSOperator, it throws an error.
It says ti or kwargs['ti'] (depending on which one I use) is undefined. Some people suggest using double curly braces {{ }}, but even that didn't work for me.
For now, I have resorted to calling getTableRowCount() directly in the parameter instead of first storing the count in a variable. It works, but I use the filename downstream at least one more time, and this approach queries the table for a row count multiple times unnecessarily.
Any help getting XCom to work, or figuring out another way to get the row count into the filename efficiently, is appreciated.
Airflow is a distributed system. It is important to note that your DAG code isn't all executed in the same context.
The DAG file is parsed and assembled on a schedule. The task is executed on a worker. XCom is available only inside the worker context, while the task is running.
As you saw, ti (or kwargs['ti']) lets you access XComs from inside a PythonOperator (specifically, inside its python_callable). The arguments you pass to BigQueryToCloudStorageOperator, however, are ordinary Python expressions evaluated when the DAG file is parsed, so no task instance exists yet and ti is undefined.
You can use Jinja templating to defer fetching the XCom until the task is running on the worker (read more here: https://airflow.apache.org/docs/apache-airflow/stable/templates-ref.html).
You probably need something like:
BigQueryToCloudStorageOperator(
    ...
    destination_cloud_storage_uris = f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_"
        + "{{ ti.xcom_pull(key='ROW_COUNTS', task_ids=['prepare_segment_tables'])[0]['column1'] }}"
        + f"_{TODAY_STR}.csv",
    ...
)
Note that I am careful not to mix f-strings and Jinja in the same string literal: both use {} syntax, so they don't play well together (inside an f-string, every literal Jinja brace would have to be doubled).
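To make the brace rules concrete, here is a small stdlib-only sketch that builds the same URI two ways: by concatenating a plain (non-f) string with the f-string pieces, and as a single f-string in which each Jinja brace is doubled. The bucket, path, date and segment values are hypothetical stand-ins for the question's constants; Airflow would render the {{ ... }} part later, on the worker.

```python
# Hypothetical stand-ins for the question's module-level constants and loop variable.
GCS_BUCKET = "my-bucket"
FILENAME_PATH = "exports/"
TODAY_STR = "20220103"
segment = {"id": "42", "name": "segment_a"}

# Plain (non-f) string holding the Jinja expression Airflow renders at run time.
ROW_COUNT_TEMPLATE = (
    "{{ ti.xcom_pull(key='ROW_COUNTS', "
    "task_ids=['prepare_segment_tables'])[0]['column1'] }}"
)

# Option 1: keep Jinja out of the f-strings and concatenate the pieces.
uri_concat = (
    f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_"
    + ROW_COUNT_TEMPLATE
    + f"_{TODAY_STR}.csv"
)

# Option 2: a single f-string. A literal brace is written twice in an f-string,
# so Jinja's "{{" must be typed as "{{{{" and "}}" as "}}}}".
uri_fstring = (
    f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_"
    f"{{{{ ti.xcom_pull(key='ROW_COUNTS', "
    f"task_ids=['prepare_segment_tables'])[0]['column1'] }}}}"
    f"_{TODAY_STR}.csv"
)

print(uri_concat)
# gs://my-bucket/exports/segment_a_{{ ti.xcom_pull(key='ROW_COUNTS', task_ids=['prepare_segment_tables'])[0]['column1'] }}_20220103.csv
```

Both forms hand Airflow the same template text; concatenation is usually easier to read and harder to get wrong.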
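You will also probably need to ensure that your XCom parses as you expect in Jinja. You may want to return your XCom differently (for example, return just the row count from the python_callable, so it becomes the task's return_value XCom) or parse it in advance with a PythonOperator.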
See here for help with what you can do within a Jinja template: https://jinja.palletsprojects.com/
Note also that this works only because "destination_cloud_storage_uris" is a templated field on this operator; not all fields are.
https://airflow.apache.org/docs/apache-airflow/1.10.15/_modules/airflow/contrib/operators/bigquery_to_gcs.html#BigQueryToCloudStorageOperator
Related
I want to branch into a TaskGroup with dynamic sub-task ids from a BranchPythonOperator.
This is the DAG flow that I have:
[Image: graph view of branch_dag]
My use case: I want to check whether a table exists in BigQuery.
If it exists: do nothing and end the DAG.
If it doesn't exist: ingest the data from Postgres to Google Cloud Storage.
I know that to branch into a TaskGroup from a BranchPythonOperator, you return the task id in the following format:
group_task_id.task_id
The problem is that my task group's sub-task ids are dynamic: they depend on how many times I loop inside the TaskGroup. So the sub-tasks will be:
parent_task_id.sub_task_1
parent_task_id.sub_task_2
parent_task_id.sub_task_3
...
parent_task_id.sub_task_x
This is the code for my DAG:
import airflow
from airflow.providers.google.cloud.transfers.postgres_to_gcs import PostgresToGCSOperator
from airflow.utils.task_group import TaskGroup
from google.cloud.exceptions import NotFound
from airflow import DAG
from airflow.operators.python import BranchPythonOperator
from airflow.operators.dummy import DummyOperator
from google.cloud import bigquery

default_args = {
    'owner': 'Airflow',
    'start_date': airflow.utils.dates.days_ago(2),
}

with DAG(dag_id='branch_dag', default_args=default_args, schedule_interval=None) as dag:

    def create_task_group(worker=1):
        var = dict()
        with TaskGroup(group_id='parent_task_id') as tg1:
            for i in range(worker):
                var[f'sub_task_{i}'] = PostgresToGCSOperator(
                    task_id = f'sub_task_{i}',
                    postgres_conn_id = 'some_postgres_conn_id',
                    sql = 'test.sql',
                    bucket = 'test_bucket',
                    filename = 'test_file.json',
                    export_format = 'json',
                    gzip = True,
                    params = {
                        'worker': worker
                    }
                )
        return tg1

    def is_exists_table():
        client = bigquery.Client()
        try:
            table_name = client.get_table('dataset_id.some_table')
            if table_name:
                return 'task_end'
        except NotFound as error:
            return 'parent_task_id'

    task_start = DummyOperator(
        task_id = 'start'
    )

    task_branch_table = BranchPythonOperator(
        task_id ='check_table_exists_in_bigquery',
        python_callable = is_exists_table
    )

    task_pg_to_gcs_init = create_task_group(worker=3)

    task_end = DummyOperator(
        task_id = 'end',
        trigger_rule = 'all_done'
    )

    task_start >> task_branch_table >> task_end
    task_start >> task_branch_table >> task_pg_to_gcs_init >> task_end
When I run the DAG, it fails with:
airflow.exceptions.TaskNotFound: Task parent_task_id not found
This is expected. What I don't know is how to iterate over parent_task_id.sub_task_x in the is_exists_table function. Or is there a workaround?
This is the test.sql file, in case it's needed:
SELECT
id,
name,
country
FROM some_table
WHERE 1=1
AND ABS(MOD(hashtext(id::TEXT), 3)) = {{params.worker}};
-- returns 1M+ rows
I have already seen this question as a reference, but I think my case is more specific.
When designing your data pipelines, you may encounter use cases that require more complex task flows than "Task A > Task B > Task C." For example, you may have a use case where you need to decide between multiple tasks to execute based on the results of an upstream task. Or you may have a case where part of your pipeline should only run under certain external conditions. Fortunately, Airflow has multiple options for building conditional logic and/or branching into your DAGs.
I found a dirty way around it.
What I did was create one additional task, task_pass, using DummyOperator, and have is_exists_table return its task id ('pass_to_task_group') instead of 'parent_task_id':
task_pass = DummyOperator(
    task_id = 'pass_to_task_group'
)
So the DAG flow now looks like this:
task_start >> task_branch_table >> task_end
task_start >> task_branch_table >> task_pass >> task_pg_to_gcs_init >> task_end
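As an alternative to the extra DummyOperator, a BranchPythonOperator callable may return a list of task ids, so is_exists_table could return the id of every sub-task in the group. Because the sub-tasks have no dependencies among themselves, task_branch_table >> task_pg_to_gcs_init makes each of them directly downstream of the branch, which BranchPythonOperator requires. The sketch below is a hypothetical helper, choose_branch, that takes the result of the existence check as a boolean so it runs without BigQuery; the group id and the default of three workers mirror the question's code. Note that the "table exists" path has to return 'end' (the task_id of task_end): returning the variable name 'task_end' would raise the same kind of TaskNotFound error.

```python
def choose_branch(table_exists: bool, group_id: str = "parent_task_id", workers: int = 3):
    """Return the task id(s) a BranchPythonOperator should follow.

    Hypothetical helper: in the DAG, table_exists would come from the
    client.get_table() / NotFound check in is_exists_table().
    """
    if table_exists:
        # Must match the task_id of the end task ('end'), not its variable name.
        return "end"
    # Task ids inside a TaskGroup are prefixed with the group id by default,
    # and range(workers) mirrors the loop in create_task_group().
    return [f"{group_id}.sub_task_{i}" for i in range(workers)]


print(choose_branch(False))
# ['parent_task_id.sub_task_0', 'parent_task_id.sub_task_1', 'parent_task_id.sub_task_2']
```

Inside is_exists_table, the try block would then return choose_branch(True) and the except NotFound block choose_branch(False), with workers kept equal to the value passed to create_task_group.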
There was also one mistake in the code above: the params value I set was worker. This is wrong because worker is a constant, while the value that needs to change per sub-task is the loop variable i. So I changed it from:
params = {'worker': worker}
to:
params = {'worker': i}
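A quick way to see why params has to carry i: test.sql keeps the rows whose bucket ABS(MOD(hashtext(id::TEXT), 3)) equals params.worker, and a modulus of 3 only ever yields 0, 1 or 2. With params = {'worker': worker}, every sub-task compared against 3 and selected nothing; with i, sub-tasks 0 to 2 split the table between them. The sketch below models this in plain Python, using zlib.crc32 as a stand-in for Postgres's hashtext (a different hash, same partitioning idea) and hypothetical integer ids. It assumes the 3 in the SQL matches create_task_group(worker=3); passing both numbers through params could keep them in sync. Separately, each sub-task probably needs its own filename (for example, one containing i), since otherwise all of them upload test_file.json to the same bucket.

```python
import zlib


def bucket_of(row_id: int, total_workers: int) -> int:
    # Stand-in for ABS(MOD(hashtext(id::TEXT), N)); crc32 is NOT Postgres's
    # hashtext, it only illustrates the partitioning idea.
    return abs(zlib.crc32(str(row_id).encode())) % total_workers


def rows_for_worker(row_ids, worker: int, total_workers: int):
    """Rows that sub_task_<worker> would select with params={'worker': worker}."""
    return [r for r in row_ids if bucket_of(r, total_workers) == worker]


row_ids = list(range(1, 1001))  # hypothetical ids
TOTAL_WORKERS = 3  # must match create_task_group(worker=3) and the SQL modulus

# One partition per sub-task, as with params={'worker': i}.
partitions = [rows_for_worker(row_ids, i, TOTAL_WORKERS) for i in range(TOTAL_WORKERS)]
print([len(p) for p in partitions])

# The original params={'worker': worker} made every sub-task ask for bucket 3.
print(len(rows_for_worker(row_ids, TOTAL_WORKERS, TOTAL_WORKERS)))  # 0
```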
I am trying to read the DAG run's run_id in a SnowflakeOperator to set a session parameter, QUERY_TAG. But it seems that session_parameters is not a templated field.
snowflake_operator = SnowflakeOperator(
    task_id='snowflake_task',
    snowflake_conn_id='snowflake_conn',
    sql='resources/some.sql',
    warehouse='my_warehouse',
    database='my_database',
    role='my_role',
    session_parameters={
        'QUERY_TAG': 'dagrun_{{run_id}}'
    }
)
How can I reference run_id and use it as an input here?
You need to make the non-templated field templated by subclassing the operator and adding session_parameters to its template_fields:
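from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator

class MySnowflakeOperator(SnowflakeOperator):
    # render session_parameters with Jinja, in addition to the parent's templated fields
    template_fields = (
        "session_parameters",
    ) + SnowflakeOperator.template_fields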
Then you can use it as:
snowflake_operator = MySnowflakeOperator(
    task_id='snowflake_task',
    snowflake_conn_id='snowflake_conn',
    sql='resources/some.sql',
    warehouse='my_warehouse',
    database='my_database',
    role='my_role',
    session_parameters={
        'QUERY_TAG': 'dagrun_{{run_id}}'
    }
)
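Why the subclass works: at run time Airflow renders only the attributes named in template_fields (recursing into dicts and lists) and passes every other argument through verbatim. The sketch below is a toy model of that rule, not Airflow's implementation: ToyOperator and MyToyOperator are hypothetical classes, and a naive regex substitution of {{ name }} stands in for Jinja.

```python
import re


def render(value, context):
    """Naive stand-in for Jinja: replace {{ name }} with context[name], recursing into containers."""
    if isinstance(value, str):
        return re.sub(r"\{\{\s*(\w+)\s*\}\}", lambda m: str(context[m.group(1)]), value)
    if isinstance(value, dict):
        return {k: render(v, context) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(render(v, context) for v in value)
    return value


class ToyOperator:
    template_fields = ("sql",)

    def __init__(self, sql, session_parameters=None):
        self.sql = sql
        self.session_parameters = session_parameters or {}

    def render_template_fields(self, context):
        # Only attributes listed in template_fields are rendered.
        for name in self.template_fields:
            setattr(self, name, render(getattr(self, name), context))


class MyToyOperator(ToyOperator):
    # Same pattern as MySnowflakeOperator: extend the parent's fields, don't replace them.
    template_fields = ("session_parameters",) + ToyOperator.template_fields


context = {"run_id": "scheduled__2022-01-01T00:00:00"}
params = {"QUERY_TAG": "dagrun_{{run_id}}"}

plain = ToyOperator("select 1", dict(params))
plain.render_template_fields(context)
custom = MyToyOperator("select 1", dict(params))
custom.render_template_fields(context)

print(plain.session_parameters)   # {'QUERY_TAG': 'dagrun_{{run_id}}'}
print(custom.session_parameters)  # {'QUERY_TAG': 'dagrun_scheduled__2022-01-01T00:00:00'}
```

Concatenating with the parent's template_fields, rather than replacing them, keeps sql templated as before.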
I have an SSHOperator that writes a filepath to stdout. I'd like to get the os.path.basename of that filepath so that I can pass it as a parameter to my next task (which is an sftp pull). The idea is to download a remote file into the current working directory. This is what I have so far:
with DAG('my_dag',
         default_args = dict(...
             xcom_push = True,
         )
) as dag:

    # there is a get_update_id task here, which has been snipped for brevity

    get_results = SSHOperator(task_id = 'get_results',
        ssh_conn_id = 'my_remote_server',
        command = """cd ~/path/to/dir && python results.py -t p -u {{ task_instance.xcom_pull(task_ids='get_update_id') }}""",
        cmd_timeout = -1,
    )

    download_results = SFTPOperator(task_id = 'download_results',
        ssh_conn_id = 'my_remote_server',
        remote_filepath = base64.b64decode("""{{ task_instance.xcom_pull(task_ids='get_results') }}"""),
        local_filepath = os.path.basename(base64.b64decode("""{{ task_instance.xcom_pull(task_ids='get_results') }}""").decode()),
        operation = 'get',
    )
Airflow tells me there's an error on the remote_filepath = line. Investigating this further, I see that the value passed to base64.b64decode is not the xcom value from the get_results task, but is rather the raw string starting with {{.
My feeling is that, since task fields are templated, there's some under-the-hood magic that resolves the template string, and that magic doesn't apply inside a call like os.path.basename. So would I need to create an intermediate task to get the basename? Is there no shorthand for what I've tried?
I'd appreciate any help on this
You want to decode the XCom return value when Airflow renders the remote_filepath property for the task instance.
This means the b64decode function must be invoked within the template string: in your version, base64.b64decode runs while the DAG file is being parsed, so it receives the literal, unrendered {{ ... }} text.
There is a catch, though: we have to make this function available in the template context, either by passing it in as a parameter or by registering it on the DAG as a user-defined filter or macro.
import base64
import os

def basename_b64decode(value):
    # decode the bytes to str first, then take the basename
    return os.path.basename(base64.b64decode(value).decode())

download_results = SFTPOperator(
    task_id = 'download_results',
    ssh_conn_id = 'my_remote_server',
    # .decode() turns the bytes into str; otherwise Jinja renders "b'...'"
    remote_filepath = """{{ params.b64decode(ti.xcom_pull(task_ids='get_results')).decode() }}""",
    local_filepath = """{{ params.basename_b64decode(ti.xcom_pull(task_ids='get_results')) }}""",
    operation = 'get',
    params = {
        'b64decode': base64.b64decode,
        'basename_b64decode': basename_b64decode,
    }
)
For the DAG user-defined macro approach, you can write:
with DAG('my_dag',
         default_args = dict(...
             xcom_push = True,
         ),
         # user_defined_macros is a DAG argument, not a default_args entry
         user_defined_macros = dict(
             basename_b64decode = basename_b64decode,
             b64decode = base64.b64decode,
         ),
) as dag:

    download_results = SFTPOperator(
        task_id = 'download_results',
        ssh_conn_id = 'my_remote_server',
        remote_filepath = """{{ b64decode(ti.xcom_pull(task_ids='get_results')).decode() }}""",
        local_filepath = """{{ basename_b64decode(ti.xcom_pull(task_ids='get_results')) }}""",
        operation = 'get',
    )
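The decoding helpers can be checked without Airflow. The sketch below simulates the XCom value SSHOperator pushes (its stdout, base64-encoded) and runs it through helpers that return str rather than bytes; by default Jinja renders a bytes value as the text "b'...'", so returning str avoids that. The .strip() is an assumption: a remote script that print()s the path also emits a trailing newline, which would otherwise end up in the file path. b64decode_str and the sample path are hypothetical.

```python
import base64
import os


def b64decode_str(value: str) -> str:
    # SSHOperator's XCom is base64-encoded stdout; decode it back to text.
    # .strip() assumes the script print()s the path with a trailing newline.
    return base64.b64decode(value).decode().strip()


def basename_b64decode(value: str) -> str:
    # Decode first, then take the basename of the resulting text path.
    return os.path.basename(b64decode_str(value))


# Simulated XCom value: what the remote script might print, base64-encoded.
xcom_value = base64.b64encode(b"/home/me/path/to/dir/results_2022.csv\n").decode()

print(b64decode_str(xcom_value))       # /home/me/path/to/dir/results_2022.csv
print(basename_b64decode(xcom_value))  # results_2022.csv
```

These could then be registered as user_defined_macros=dict(b64decode=b64decode_str, basename_b64decode=basename_b64decode), in which case the templates no longer need the trailing .decode().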
I have a DAG with a structure like this:
def do_thing(**kwargs):
    lib.do_thing(kwargs['thing_type'])

do_thing_type_a = PythonOperator(
    task_id = "do_thing_type_a",
    python_callable = do_thing,
    op_kwargs = {"thing_type": "A"}
)

do_thing_type_b = PythonOperator(
    task_id = "do_thing_type_b",
    python_callable = do_thing,
    op_kwargs = {"thing_type": "B"}
)

do_thing_type_a
do_thing_type_b
The real DAG is more complex and has other functions downstream of do_thing_type_a and do_thing_type_b, but I think this demonstrates the particular question.
Essentially, in the same DAG I'd like to execute the same function twice, but with a different value for the parameter. The simple way, as in this example, is to create two tasks (one for each type) and run them both. But this feels like it violates basic DRY principles and hurts maintainability: if I want to change the task, I have to make the change in both versions of it.
Ideally, I would like to define one task, such as do_thing_type, and then pass a parameter like 'A' or 'B' when creating the dependencies, but I don't know if there's any way to do this.
What is the best approach for this in Airflow?
You can do that in a loop:
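def do_thing(**kwargs):
    lib.do_thing(kwargs['thing_type'])

items_dict = {"a": "A", "b": "B"}
do_thing_tasks = {}
for key, val in items_dict.items():
    # one task per thing type; keep a reference for wiring downstream dependencies
    do_thing_tasks[key] = PythonOperator(
        task_id = f"do_thing_type_{key}",
        python_callable = do_thing,
        op_kwargs = {"thing_type": val}
    )
# e.g. do_thing_tasks["a"] >> downstream_task  (downstream_task is hypothetical)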
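Passing the value through op_kwargs, as the loop does, also avoids a classic pitfall if you are tempted to use a lambda as python_callable instead: a Python closure looks up a loop variable when it is called, not when it is defined, so every task would see the last value. The stdlib-only sketch below shows this with functools.partial standing in for op_kwargs (both bind the value when the task is created); do_thing here is a hypothetical stand-in for lib.do_thing that just records its argument.

```python
from functools import partial

calls = []


def do_thing(thing_type):
    # Stand-in for lib.do_thing; records which type it was called with.
    calls.append(thing_type)


items_dict = {"a": "A", "b": "B"}

# Pitfall: each lambda looks up `val` when it runs, after the loop has finished.
late_bound = [lambda: do_thing(val) for key, val in items_dict.items()]

# partial (like op_kwargs) captures the current value at creation time.
early_bound = [partial(do_thing, thing_type=val) for key, val in items_dict.items()]

for fn in late_bound:
    fn()
print(calls)  # ['B', 'B']

calls.clear()
for fn in early_bound:
    fn()
print(calls)  # ['A', 'B']
```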
Sometimes I find it handy to create tasks using a loop.
Below is an example of a SqoopOperator in which I use the XCom value from the previous PythonOperator in the where clause. I am trying to use a variable, get_delivery_sqn_task_id, to access the correct XCom value with ti.xcom_pull(task_ids=get_delivery_sqn_task_id), but this does not work (it returns ()).
I could take everything out of the loop, but I think that makes the code quite ugly. Is there an elegant way to use a variable for task_ids when retrieving XCom values? Otherwise, I guess the best solution is to use Airflow Variables.
for table in tables:
    get_delivery_sqn_task_id = 'get_delivery_sqn_' + table
    get_delivery_sqn_task = PythonOperator(
        task_id = get_delivery_sqn_task_id,
        python_callable = get_delivery_sqn,
        op_kwargs = {
            'table_name': table
        },
        provide_context = True,
        dag = dag
    )
    sqoop_operator_task = SqoopOperator(
        task_id = "sqoop_" + table,
        conn_id = "DWDH_PROD",
        table = table,
        cmd_type = "import",
        target_dir = "/sourcedata/sqoop_tmp/" + table,
        num_mappers = 1,
        where = "delivery_sqn > {{ ti.xcom_pull(task_ids=get_delivery_sqn_task_id, key='return_value') }}",
        dag = dag
    )
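The Jinja expression is rendered by Airflow at run time, where your Python variable get_delivery_sqn_task_id doesn't exist, so substitute the task id into the string before Airflow sees it. With str.format, doubled braces stay literal, so {{{{ ... }}}} comes out as Jinja's {{ ... }}; quote the placeholder so that Jinja receives a string literal:
"delivery_sqn > {{{{ ti.xcom_pull(task_ids='{}', key='return_value') }}}}".format(get_delivery_sqn_task_id)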