I am trying to read the run_id of the DAG run in a SnowflakeOperator so I can set a session parameter, QUERY_TAG. But it seems the session_parameters field is not templated.
snowflake_operator = SnowflakeOperator(
    task_id='snowflake_task',
    snowflake_conn_id='snowflake_conn',
    sql='resources/some.sql',
    warehouse='my_warehouse',
    database='my_database',
    role='my_role',
    session_parameters={
        'QUERY_TAG': 'dagrun_{{run_id}}'
    }
)
How can I reference run_id and use it as an input here?
You need to make the non-templated field templated by subclassing the operator and adding the field to its template_fields:
class MySnowflakeOperator(SnowflakeOperator):
    template_fields = (
        "session_parameters",
    ) + SnowflakeOperator.template_fields
Then you can use it as:
snowflake_operator = MySnowflakeOperator(
    task_id='snowflake_task',
    snowflake_conn_id='snowflake_conn',
    sql='resources/some.sql',
    warehouse='my_warehouse',
    database='my_database',
    role='my_role',
    session_parameters={
        'QUERY_TAG': 'dagrun_{{run_id}}'
    }
)
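To see why this works: Airflow renders every field listed in template_fields and recurses into dicts and lists, so the nested QUERY_TAG string gets rendered against the run's context. A rough plain-Jinja2 sketch of the effect (not Airflow internals, and the run_id value here is made up):

from jinja2 import Template

session_parameters = {'QUERY_TAG': 'dagrun_{{run_id}}'}

# conceptually, each nested string value is rendered against the task context
rendered = {k: Template(v).render(run_id='manual__2022-01-01T00:00:00+00:00')
            for k, v in session_parameters.items()}
print(rendered)  # {'QUERY_TAG': 'dagrun_manual__2022-01-01T00:00:00+00:00'}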
I need to count the number of rows in a table and use the row count in the filename of an export to GCS. The following is an excerpt from my DAG.
with models.DAG(
    'my_dag',
    schedule_interval = '0 6 * * 1',
    start_date = datetime(2022, 1, 1),
    catchup = False
) as dag:

    # create segment-filtered views to output CSV to GCS
    def prepareSegmentTables(segment, **kwargs):
        segment_table_queries = f"""
            TRUNCATE TABLE dataset.some_table;
            INSERT INTO dataset.some_table (column1)
            SELECT DISTINCT column1
            FROM dataset.some_other_table
            WHERE column2 = '{ segment['id'] }';
        """
        # execute query
        client.query(segment_table_queries).result()

        # store the row counts of each type
        kwargs['ti'].xcom_push(
            key = "ROW_COUNTS",
            value = {
                "column1": getTableRowCount("dataset.some_table"),
            }
        )

    def get_row_counts(segment, **kwargs):
        ROW_COUNTS = kwargs['ti'].xcom_pull(
            key = "ROW_COUNTS",
            task_ids = [ f"prepare_segment_tables" ]
        )

    # tasks
    prepare_segment_tables = PythonOperator(
        task_id = f"prepare_segment_tables",
        python_callable = prepareSegmentTables,
        op_kwargs = { "segment": segment },
        dag = dag
    )

    export_to_gcs = BigQueryToCloudStorageOperator(
        task_id = f"gcs_lr_to_li_auid_{segment['id']}",
        source_project_dataset_table = f"{GCP_PROJECT}.{DATASET_NAME}.some_table",
        destination_cloud_storage_uris = f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_"
            + str( ti.xcom_pull(key = "ROW_COUNTS", task_ids = [ f"prepare_segment_tables" ])[0].column1 )
            + f"_{TODAY_STR}.csv",
        # this works though
        # destination_cloud_storage_uris = f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_" + str( getTableRowCount("dataset.some_table") ) + f"_{TODAY_STR}.csv",
        compression = 'NONE', export_format = 'CSV', field_delimiter = ',', print_header = True
    )

    prepare_segment_tables >> export_to_gcs
As can be seen, I am pushing ROW_COUNTS into XCom while calling prepareSegmentTables via a PythonOperator. When I do xcom_pull inside another PythonOperator that calls get_row_counts, it properly pulls the value, but when I pass the same syntax as a parameter to BigQueryToCloudStorageOperator or BigQueryToGCSOperator, it throws an error.
It says ti or kwargs['ti'], depending on which I use, is undefined. Some people suggest using double {{ }}, and even that didn't work for me.
For now, I have resorted to calling getTableRowCount() directly in the parameter instead of first storing it in a variable. It works, but I use the filename downstream at least one more time, and this approach means the table is queried for its row count more times than necessary.
Any help getting XCom to work, or with finding a way to get the row count into the filename efficiently, would be appreciated.
Airflow is a distributed system. It is important to note that your DAG code isn't executed all in the same context.
The DAG is parsed and assembled on a schedule. The Task is executed on a worker. XCOM is available inside the worker context.
As you saw, ti or kwargs['ti'] lets you access XComs from inside a PythonOperator (specifically its python_callable), but in your BigQueryToCloudStorageOperator you don't have task_instance available, because you aren't in that context.
You can use Jinja templating to defer fetching the XCom until you are in the correct context on the worker (read more here: https://airflow.apache.org/docs/apache-airflow/stable/templates-ref.html ).
You probably need something like:
BigQueryToCloudStorageOperator(
    ...
    destination_cloud_storage_uris = f"gs://{GCS_BUCKET}/{FILENAME_PATH}{segment['name']}_"
        + "{{ ti.xcom_pull(key='ROW_COUNTS', task_ids=['prepare_segment_tables'])[0]['column1'] }}"
        + f"_{TODAY_STR}.csv",
    ...
)
Note that I am being careful not to mix f-strings and Jinja within the same string there, as they both use the {} syntax and don't play well together.
You will also probably need to ensure that your XCom will parse as you expect in Jinja, and you may want to return your XCom differently or parse it in advance with a PythonOperator.
See here for help with what you can do within a Jinja template: https://jinja.palletsprojects.com/
Note also that this works only because "destination_cloud_storage_uris" is a templated field on this operator - and not all fields are.
https://airflow.apache.org/docs/apache-airflow/1.10.15/_modules/airflow/contrib/operators/bigquery_to_gcs.html#BigQueryToCloudStorageOperator
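For example, one way to avoid mixing f-strings and Jinja at all is to build the full filename inside the Python task and push it as its own XCom, so the downstream operator only has to pull a plain string. A rough, untested sketch reusing the names from the question (the EXPORT_FILENAME key is made up):

def prepareSegmentTables(segment, **kwargs):
    # ... run the TRUNCATE/INSERT queries as before ...
    row_count = getTableRowCount("dataset.some_table")
    filename = f"{FILENAME_PATH}{segment['name']}_{row_count}_{TODAY_STR}.csv"
    kwargs['ti'].xcom_push(key = "EXPORT_FILENAME", value = filename)

export_to_gcs = BigQueryToCloudStorageOperator(
    task_id = f"gcs_lr_to_li_auid_{segment['id']}",
    source_project_dataset_table = f"{GCP_PROJECT}.{DATASET_NAME}.some_table",
    destination_cloud_storage_uris = [
        f"gs://{GCS_BUCKET}/"
        + "{{ ti.xcom_pull(key='EXPORT_FILENAME', task_ids='prepare_segment_tables') }}"
    ],
    compression = 'NONE', export_format = 'CSV', field_delimiter = ',', print_header = True,
)

This way the row count is queried once, and any other downstream task that needs the filename can pull the same XCom.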
I have an SSHOperator that writes a filepath to stdout. I'd like to get the os.path.basename of that filepath so that I can pass it as a parameter to my next task (which is an sftp pull). The idea is to download a remote file into the current working directory. This is what I have so far:
with DAG('my_dag',
         default_args = dict(...
             xcom_push = True,
         )
         ) as dag:

    # there is a get_update_id task here, which has been snipped for brevity

    get_results = SSHOperator(task_id = 'get_results',
        ssh_conn_id = 'my_remote_server',
        command = """cd ~/path/to/dir && python results.py -t p -u {{ task_instance.xcom_pull(task_ids='get_update_id') }}""",
        cmd_timeout = -1,
    )

    download_results = SFTPOperator(task_id = 'download_results',
        ssh_conn_id = 'my_remote_server',
        remote_filepath = base64.b64decode("""{{ task_instance.xcom_pull(task_ids='get_results') }}"""),
        local_filepath = os.path.basename(base64.b64decode("""{{ task_instance.xcom_pull(task_ids='get_results') }}""").decode()),
        operation = 'get',
    )
Airflow tells me there's an error on the remote_filepath = line. Investigating this further, I see that the value passed to base64.b64decode is not the XCom value from the get_results task, but rather the raw string starting with {{.
My feeling is that since task fields are templated, there's some under-the-hood magic that resolves the templated string, and os.path.basename obviously can't take part in that at DAG-definition time. So would I need to create an intermediate task to get the basename? Is there no way to shorthand this the way I've tried?
I'd appreciate any help on this
You want to decode the XCOM return value when Airflow renders the remote_filepath property for the Task instance.
This means that the b64decode function must be invoked within the template string.
There is a catch, though: we have to make this function available in the template context, either by providing it as a parameter or at the DAG level as a user-defined filter or macro.
def basename_b64decode(value):
    return os.path.basename(base64.b64decode(value)).decode()

download_results = SFTPOperator(
    task_id = 'download_results',
    ssh_conn_id = 'my_remote_server',
    remote_filepath = """{{ params.b64decode(ti.xcom_pull(task_ids='get_results')).decode() }}""",
    local_filepath = """{{ params.basename_b64decode(ti.xcom_pull(task_ids='get_results')) }}""",
    operation = 'get',
    params = {
        'b64decode': base64.b64decode,
        'basename_b64decode': basename_b64decode,
    }
)
For the DAG user-defined macro approach, you can write:
with DAG('my_dag',
         default_args = dict(...
             xcom_push = True,
         ),
         user_defined_macros = dict(
             basename_b64decode = basename_b64decode,
             b64decode = base64.b64decode,
         )
         ) as dag:
    download_results = SFTPOperator(
        task_id = 'download_results',
        ssh_conn_id = 'my_remote_server',
        remote_filepath = """{{ b64decode(ti.xcom_pull(task_ids='get_results')).decode() }}""",
        local_filepath = """{{ basename_b64decode(ti.xcom_pull(task_ids='get_results')) }}""",
        operation = 'get',
    )
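Airflow also accepts user_defined_filters on the DAG, which reads a little more naturally as a Jinja pipeline. A sketch under the same assumptions (note that here b64decode is wrapped so it returns a string rather than bytes):

with DAG('my_dag',
         user_defined_filters = dict(
             b64decode = lambda v: base64.b64decode(v).decode(),
             basename_b64decode = basename_b64decode,
         )
         ) as dag:

    download_results = SFTPOperator(
        task_id = 'download_results',
        ssh_conn_id = 'my_remote_server',
        remote_filepath = "{{ ti.xcom_pull(task_ids='get_results') | b64decode }}",
        local_filepath = "{{ ti.xcom_pull(task_ids='get_results') | basename_b64decode }}",
        operation = 'get',
    )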
Sometimes I find it handy to create tasks using a loop.
Below is an example of a SqoopOperator in which I use the XCom value from the previous PythonOperator in the where clause. I am trying to use a variable, get_delivery_sqn_task_id, to access the correct XCom value with ti.xcom_pull(task_ids=get_delivery_sqn_task_id); however, this does not work (it returns ()).
I can take everything out of the loop, but that makes the code quite ugly, I think. Is there an elegant solution for using a variable task_ids to retrieve XCom values? I guess otherwise the best solution is using Airflow Variables.
for table in tables:
    get_delivery_sqn_task_id = 'get_delivery_sqn_' + table

    get_delivery_sqn_task = PythonOperator(
        task_id = get_delivery_sqn_task_id,
        python_callable = get_delivery_sqn,
        op_kwargs = {
            'table_name': table
        },
        provide_context = True,
        dag = dag
    )

    sqoop_operator_task = SqoopOperator(
        task_id = "sqoop_" + table,
        conn_id = "DWDH_PROD",
        table = table,
        cmd_type = "import",
        target_dir = "/sourcedata/sqoop_tmp/" + table,
        num_mappers = 1,
        where = "delivery_sqn > {{ ti.xcom_pull(task_ids=get_delivery_sqn_task_id, key='return_value') }}",
        dag = dag
    )
You can do:
"delivery_sqn > {{{{ ti.xcom_pull(task_ids={}, key='return_value') }}}}".format(get_delivery_sqn_task_id)
I'm having a hard time figuring out how to make the code below take an instance ID as a parameter and display info only about that instance:
import boto3

ec2_client = boto3.client('ec2')
instance_data = []
instanceid = 'i-123456'

def get_instance_data(instanceid):
    reservations = ec2_client.describe_instances()
    for reservation in reservations['Reservations']:
        for instance in reservation['Instances']:
            instance_data.append(
                {
                    'instanceId': instance['InstanceId'],
                    'instanceType': instance['InstanceType'],
                    'launchDate': instance['LaunchTime'].strftime('%Y-%m-%dT%H:%M:%S.%f')
                }
            )
    print(instance_data)

get_instance_data(instanceid)
My understanding is that get_instance_data() should take the instance ID and display only the info for that instance, even though it iterates over the whole 'Reservations' dict that is returned.
Am I missing something here?
Thanks in advance!
Your code appends all instances to instance_data.
In order to receive only the required instance, you need to change
reservations = ec2_client.describe_instances()
to the following:
reservations = ec2_client.describe_instances(
    InstanceIds=['i-123456'])
describe_instances() lets you specify filters; however, the syntax can be quite confusing, because you can specify the instance ID either inside the Filters parameter or via the InstanceIds parameter.
e.g.
# method 1
response = client.describe_instances(
    InstanceIds=[
        'i-123456',
    ],
)

# method 2
response = client.describe_instances(
    Filters=[
        {
            'Name': 'instance-id',
            'Values': ['i-123456']
        },
    ],
)
The documentation for Filters (list) is utterly confusing, because the only correct way to pass the filter name into the list is to wrap the 'Name' and 'Values' keys explicitly inside a dict.
# This will not work
response = client.describe_instances(
    Filters=[
        'instance-id:i-123456'
    ])

# Neither will this!
response = client.describe_instances(
    Filters=[
        {'instance-id': ['i-123456']}
    ])
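Putting that together with the original function, a minimal untested sketch that actually uses the instanceid argument might look like this:

import boto3

ec2_client = boto3.client('ec2')

def get_instance_data(instanceid):
    instance_data = []
    # ask the API for just this instance instead of filtering client-side
    reservations = ec2_client.describe_instances(InstanceIds=[instanceid])
    for reservation in reservations['Reservations']:
        for instance in reservation['Instances']:
            instance_data.append({
                'instanceId': instance['InstanceId'],
                'instanceType': instance['InstanceType'],
                'launchDate': instance['LaunchTime'].strftime('%Y-%m-%dT%H:%M:%S.%f'),
            })
    print(instance_data)
    return instance_data

get_instance_data('i-123456')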
I have a simple REST/JSON API. I am trying to update a mongoengine model, Listing, using a PUT call. I fetch one from MongoDB and update it with the incoming JSON using my own deserialize method. The update does not work, as the JSON has a few DBRefs and EmbeddedDocuments. The object does get updated with the correct values before the save is executed, and there is no error, but the object does not get saved. Any thoughts?
obj = Listing.objects.get(pk=id)
obj.deserialize(**request.json)
obj.save()
obj.reload()
return obj
class Listing():
    name = db.StringField()
    l_type = db.StringField( choices=listing_const.L_TYPE.choices(), )
    expiry = db.ComplexDateTimeField( auto_now=False, auto_now_add=False, )
    a_data = db.EmbeddedDocumentField(Media)
    lcrr = db.ReferenceField( 'LCRR', reverse_delete_rule=3, dbref=False, )
    meta = {
        'db_alias': config.get_config()['MONGODB_SETTINGS']['alias'],
        'cascade': True
    }