Querying Athena from Lambda function - QUEUED state? - python

I've been successfully querying S3 via Athena from inside a Lambda function for quite some time, but it has suddenly stopped working. Further investigation shows that the response from get_query_execution() is returning a state of 'QUEUED' (which I was led to believe is not used?!)
My code is as follows:
def run_query(query, database, s3_output, max_execution=5):
    response = client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={
            'Database': database
        },
        ResultConfiguration={
            'OutputLocation': s3_output
        })
    execution_id = response['QueryExecutionId']
    print("QueryExecutionId = " + str(execution_id))
    state = 'RUNNING'
    while (max_execution > 0 and state in ['RUNNING']):
        max_execution = max_execution - 1
        print("maxexecution=" + str(max_execution))
        response = client.get_query_execution(QueryExecutionId = execution_id)
        if 'QueryExecution' in response and \
                'Status' in response['QueryExecution'] and \
                'State' in response['QueryExecution']['Status']:
            state = response['QueryExecution']['Status']['State']
            print(state)
            if state == 'SUCCEEDED':
                print("Query SUCCEEDED: {}".format(execution_id))
                s3_key = 'athena_output/' + execution_id + '.csv'
                print(s3_key)
                local_filename = '/tmp/' + execution_id + '.csv'
                print(local_filename)
                rows = []
                try:
                    print("s3key =" + s3_key)
                    print("localfilename = " + local_filename)
                    s3.Bucket(BUCKET).download_file(s3_key, local_filename)
                    with open(local_filename) as csvfile:
                        reader = csv.DictReader(csvfile)
                        for row in reader:
                            rows.append(row)
                except botocore.exceptions.ClientError as e:
                    if e.response['Error']['Code'] == "404":
                        print("The object does not exist.")
                        print(e)
                    else:
                        raise
                return json.dumps(rows)
            elif state == 'FAILED':
                return False
        time.sleep(10)
    return False
So the code is obviously doing what it was written to do; it's just that the 'QUEUED' state is completely unexpected and I'm not sure what to do about it. What can cause the query execution to become 'QUEUED', and what needs to change in my code to cater for it?

Take a look at the Athena hook in Apache Airflow. Athena has final states (SUCCEEDED, FAILED and CANCELLED) and intermediate states (RUNNING and QUEUED). QUEUED is a normal state for a query before it has started. So you could use code like this:
def run_query(query, database, s3_output, max_execution=5):
    response = client.start_query_execution(
        QueryString=query,
        QueryExecutionContext={
            'Database': database
        },
        ResultConfiguration={
            'OutputLocation': s3_output
        })
    execution_id = response['QueryExecutionId']
    print("QueryExecutionId = " + str(execution_id))
    state = 'QUEUED'
    while (max_execution > 0 and state in ['RUNNING', 'QUEUED']):
        max_execution = max_execution - 1
        print("maxexecution=" + str(max_execution))
        response = client.get_query_execution(QueryExecutionId = execution_id)
        if 'QueryExecution' in response and \
                'Status' in response['QueryExecution'] and \
                'State' in response['QueryExecution']['Status']:
            state = response['QueryExecution']['Status']['State']
            print(state)
            if state == 'SUCCEEDED':
                print("Query SUCCEEDED: {}".format(execution_id))
                s3_key = 'athena_output/' + execution_id + '.csv'
                print(s3_key)
                local_filename = '/tmp/' + execution_id + '.csv'
                print(local_filename)
                rows = []
                try:
                    print("s3key =" + s3_key)
                    print("localfilename = " + local_filename)
                    s3.Bucket(BUCKET).download_file(s3_key, local_filename)
                    with open(local_filename) as csvfile:
                        reader = csv.DictReader(csvfile)
                        for row in reader:
                            rows.append(row)
                except botocore.exceptions.ClientError as e:
                    if e.response['Error']['Code'] == "404":
                        print("The object does not exist.")
                        print(e)
                    else:
                        raise
                return json.dumps(rows)
            elif state == 'FAILED' or state == 'CANCELLED':
                return False
        time.sleep(10)
    return False

Got this response from AWS. There have been changes to Athena that caused this issue (although QUEUED has been in the state enum for some time, it hasn't been used until now):
The Athena team recently deployed a host of new functionality for Athena, including more granular CloudWatch metrics for Athena queries.
For more information:
AWS What's New page
Athena docs on CloudWatch metrics
As part of the deployment of more granular metrics, Athena now includes a QUEUED status for queries. This status indicates that an Athena query is waiting for resources to be allocated for processing. Query flow is roughly:
SUBMITTED -> QUEUED -> RUNNING -> COMPLETED/FAILED
Note that queries that fail due to system errors can be put back into the queue and retried.
I apologise for the frustration that this change has caused.
It seems like the forum formatting has stripped some elements from your code snippets.
However, I think that your WHILE loop is working on an array of the possible query statuses, which didn't previously cater for QUEUED.
If that is the case, then yes, adding QUEUED to that array will allow your application to handle the new status.
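The state handling discussed in the answers above can be sketched as a small polling helper. This is only an illustration: `wait_for_query`, `poll_seconds`, `max_polls` and the injectable `sleep` parameter are names made up here, `FakeAthena` is a hypothetical stand-in used so the example runs without AWS, and `client` is assumed to be any object exposing Athena's `get_query_execution` (for example a boto3 Athena client).

```python
import time

# QUEUED and RUNNING are intermediate; SUCCEEDED, FAILED and CANCELLED are final.
INTERMEDIATE_STATES = {"QUEUED", "RUNNING"}


def wait_for_query(client, execution_id, poll_seconds=10, max_polls=5, sleep=time.sleep):
    """Poll until the query leaves the intermediate states or polls run out.

    Returns the last state seen, which may still be QUEUED or RUNNING
    if max_polls was exhausted.
    """
    state = "QUEUED"
    for attempt in range(max_polls):
        response = client.get_query_execution(QueryExecutionId=execution_id)
        state = response["QueryExecution"]["Status"]["State"]
        if state not in INTERMEDIATE_STATES:
            break
        if attempt < max_polls - 1:
            sleep(poll_seconds)
    return state


class FakeAthena:
    """Hypothetical stand-in that replays a fixed sequence of states."""

    def __init__(self, states):
        self._states = iter(states)

    def get_query_execution(self, QueryExecutionId):
        return {"QueryExecution": {"Status": {"State": next(self._states)}}}


final_state = wait_for_query(
    FakeAthena(["QUEUED", "QUEUED", "RUNNING", "SUCCEEDED"]),
    "example-id",
    sleep=lambda seconds: None,  # skip real waiting in this demonstration
)
print(final_state)
```

With a 10-second interval and 5 polls, a query that sits in the queue for a while may still be unfinished when the helper returns, so the caller should probably treat a returned QUEUED or RUNNING as "not done yet" rather than as a failure.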

Related

How do I run an AWS Lambda function on a schedule depending on the output of my code

The purpose of my Lambda function is to check whether a success file is in a certain S3 bucket path. If it's not there, it should send a failure message to Slack. If it is there, it should trigger a DAG run manually. At the moment I've got the code working for that, but I want to retry the function every 15 minutes for the next 10 hours IF there is no success file. If after all those attempts it's still not there, then I want to send the failure notification to Slack. At the moment, my EventBridge cron schedule for the Lambda is spamming the Slack failure message x number of times or triggering the DAG x number of times. I want to exit the schedule if there is a success file and trigger the DAG once, but I want it to run again if it's not there, and then send a failure message on the last retry that I've got scheduled. Is this possible? My code is below, and my EventBridge cron is for Monday-Friday, every 15 minutes from 4 am to 1 pm. Any guidance on a possible solution would be appreciated!
def lambda_handler(event, context):
    #slack and airflow config
    slack_webhook_url_details = ""
    dag_id = 'sample_dag'
    airflow_url=''
    def send_slack_message(slack_webhook_url,slack_message):
        slack_payload = {'text':slack_message}
        response = requests.post(slack_webhook_url, json.dumps(slack_payload))
        response_json = response.text
        print('response after posting to slack: '+ str(response_json))
    def initialize_paths():
        global bucket
        global path
        global dt
        two_days_ago = date.today() - timedelta(days=2)
        dt = two_days_ago.strftime('%Y%m%d')
        #add for loop logic for iterating through countries in bucket
        bucket = ""
        path = f"test_lambda/US/{dt}/_SUCCESS"
    def check_file_exists():
        print(bucket,path)
        s3 = resource('s3')
        try:
            #check if file exists
            s3.Object(bucket, path).load()
            logging.info(f'_SUCCESS file exists with at path: {bucket}/{path} for the following date: {dt}')
            #trigger dag run
            mwaa_env_name = 'airflow-prod-env'
            dag_name= 'sample_dag'
            mwaa_cli_command = 'dags trigger'
            client = boto3.client('mwaa')
            mwaa_cli_token = client.create_cli_token(Name=mwaa_env_name)
            conn = http.client.HTTPSConnection(mwaa_cli_token['WebServerHostname'])
            payload = mwaa_cli_command + " " + dag_name
            headers = {
                'Authorization': 'Bearer ' + mwaa_cli_token['CliToken'],
                'Content-Type': 'text/plain'
            }
            conn.request("POST", "/aws_mwaa/cli", payload, headers)
            res = conn.getresponse()
            data = res.read()
            dict_str = data.decode("UTF-8")
            mydata = ast.literal_eval(dict_str)
            return base64.b64decode(mydata['stdout'])
        except botocore.exceptions.ClientError as errorStdOut:
            if errorStdOut.response['Error']['Code'] >= "401":
                error_message= f'_SUCCESS file NOT detected at path: {bucket}/{path} for the following date: {dt}'
                logging.info(error_message)
                send_slack_message(slack_webhook_url_details,error_message)
            else:
                logging.info('Error, something went wrong connecting to lambda')
    initialize_paths()
    check_file_exists()

Python | Multithreading script does not complete

I have this multithreading script, which operates on a data set. Each thread gets a chunk of the data set, iterates over its data frame and calls an API (MS Graph Create) for each row.
What I have seen is that my script tends to get stuck when it is almost finished. I am running this on a Linux Ubuntu server with 8 vCPUs. This happens only when the total dataset size is in the millions (it takes around 9-10 hrs for 2 million records).
I am writing a long-running script for the first time and would like an opinion on whether I am doing things correctly.
Please:
I would like to know if my code is the reason why my script hangs.
Have I done multithreading correctly? Have I created and waited for the threads to end correctly?
UPDATE
Using the answers below, the threads still seem to get stuck at the end.
import pandas as pd
import sys
import os
import logging
import string
import secrets
import random
import time
import threading
import requests
##### ----- Logging Setup -------
logging.basicConfig(filename="pylogs.log", format='%(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
# Creating an object
logger = logging.getLogger()
# Setting the threshold of logger to ERROR
logger.setLevel(logging.ERROR)
#####------ Function Definitions -------
# generates random password
def generateRandomPassword(lengthOfPassword):
    # logic for random password gen
    pass
# the most important function
#
def createAccounts(splitData, threadID):
    batchProgress = 0
    batch_size = splitData.shape[0]
    for row in splitData.itertuples():
        try:
            headers = {"Content-Type": "application/json", "Authorization":"Bearer "+access_token}
            randomLength = [8,9,12,13,16]
            passwordLength = random.choice(randomLength)
            password = generateRandomPassword(passwordLength) # will be generated randomly - for debugging purpose
            batchProgress+=1
            post_request_body = {
                "accountEnabled": True,
                "displayName": row[5],
                "givenName": row[3],
                "surname": row[4],
                "mobilePhone": row[1],
                "mail": row[2],
                "passwordProfile" : {
                    "password": password,
                    "forceChangePasswordNextSignIn": False
                },
                "state":"",
                "identities": [
                    {
                        "signInType": "emailAddress",
                        "issuer": tenantName,
                        "issuerAssignedId": row[2]
                    }
                ]
            }
            # if phone number exists then only add - since phone number needs to have length between 1 and 64, cannot leave empty
            if(len(row[4])):
                post_request_body["identities"].append({"signInType": "phoneNumber","issuer": tenantName,"issuerAssignedId": row[1]})
            responseFromApi = requests.post(graph_api_create, headers=headers, json=post_request_body)
            status = responseFromApi.status_code
            if(status == 201): #success
                id = responseFromApi.json().get("id")
                print(f" {status} | {batchProgress} / {batch_size} | Success {id}")
                errorDict = f'{row[1]}^{row[2]}^{row[3]}^{row[4]}^{row[5]}^{row[6]}^{row[7]}^{row[8]}^{row[9]}^{row[10]}{row[11]}{row[12]}{row[13]}{row[11]}{row[12]}{row[13]}^Success'
            elif(status == 429): #throttling issues
                print(f" Thread {threadID} | Throttled by server ! Sleeping for 150 seconds")
                errorDict = f'{row[1]}^{row[2]}^{row[3]}^{row[4]}^{row[5]}^{row[6]}^{row[7]}^{row[8]}^{row[9]}^{row[10]}{row[11]}{row[12]}{row[13]}^Throttled'
                time.sleep(150)
            elif(status == 401): #token expiry
                print(f" Thread {threadID} | Token Expired. Getting it back !")
                errorDict = f'{row[1]}^{row[2]}^{row[3]}^{row[4]}^{row[5]}^{row[6]}^{row[7]}^{row[8]}^{row[9]}^{row[10]}{row[11]}{row[12]}{row[13]}^Token Expired'
                getRefreshToken()
            else: #any other error
                msg = ""
                try:
                    msg = responseFromApi.json().get("error").get("message")
                except Exception as e:
                    msg = f"Error {e}"
                errorDict = f'{row[1]}^{row[2]}^{row[3]}^{row[4]}^{row[5]}^{row[6]}^{row[7]}^{row[8]}^{row[9]}^{row[10]}{row[11]}{row[12]}{row[13]}^{msg}'
                print(f" {status} | {batchProgress} / {batch_size} | {msg} {row[2]}")
            logger.error(errorDict)
        except Exception as e:
            # check for refresh token errors
            errorDict = f'{row[1]}^{row[2]}^{row[3]}^{row[4]}^{row[5]}^{row[6]}^{row[7]}^{row[8]}^{row[9]}^{row[10]}{row[11]}{row[12]}{row[13]}^Exception_{e}'
            logger.error(errorDict)
            msg = " Error "
            print(f" {status} | {batchProgress} / {batch_size} | {msg} {row[2]}")
    print(f"Thread {threadID} completed ! {batchProgress} / {batch_size}")
    batchProgress = 0
###### ------ Main Script ------
if __name__ == "__main__":
    # get file name and appid from command line arguments
    storageFileName = sys.argv[1]
    appId = sys.argv[2]
    # setup credentials
    bigFilePath = f"./{storageFileName}"
    CreatUserUrl = "https://graph.microsoft.com/v1.0/users"
    B2C_Tenant_Name = "tenantName"
    tenantName = B2C_Tenant_Name + ".onmicrosoft.com"
    applicationID = appId
    accessSecret = "" # will be taken from command line in future revisions
    token_api_body = {
        "grant_type": "client_credentials",
        "scope": "https://graph.microsoft.com/.default",
        "client_Id" : applicationID,
        "client_secret": accessSecret
    }
    # Get initial access token from MS
    print("Connecting to MS Graph API")
    token_api = "https://login.microsoftonline.com/"+tenantName+"/oauth2/v2.0/token"
    response = {}
    try:
        responseFromApi = requests.post(token_api, data=token_api_body)
        responseJson = responseFromApi.json()
        print(f"Token API Success ! Expires in {responseJson.get('expires_in')} seconds")
    except Exception as e:
        print("ERROR | Token auth failed ")
    # if we get the token proceed else abort
    if(responseFromApi.status_code == 200):
        migrationData = pd.read_csv(bigFilePath)
        print(" We got the data from Storage !", migrationData.shape[0])
        global access_token
        access_token = responseJson.get('access_token')
        graph_api_create = "https://graph.microsoft.com/v1.0/users"
        dataSetSize = migrationData.shape[0]
        partitions = 50 # No of partitions # will be taken from command line in future revisions
        size = int(dataSetSize/partitions) # No of rows per file
        remainder = dataSetSize%partitions
        print(f"Data Set Size : {dataSetSize} | Per file size = {size} | Total Files = {partitions} | Remainder: {remainder} | Start...... \n")
        ##### ------- Dataset partitioning.
        datasets = []
        range_val = partitions + 1 if remainder !=0 else partitions
        for partition in range(range_val):
            if(partition == partitions):
                df = migrationData[size*partition:dataSetSize]
            else:
                df = migrationData[size*partition:size*(partition+1)]
            datasets.append(df)
        number_of_threads = len(datasets)
        start_time = time.time()
        spawned_threads = []
        ######## ---- Threads are spawned ! here --------
        for i in range(number_of_threads): # spawn threads
            t = threading.Thread(target=createAccounts, args=(datasets[i], i))
            t.start()
            spawned_threads.append(t)
        number_spawned = len(spawned_threads)
        print(f"Started {number_spawned} threads !")
        ###### - Threads are killed here ! ---------
        for thread in spawned_threads: # let the script wait for thread execution
            thread.join()
        print(f"Done! It took {time.time() - start_time}s to execute") # time check
        #### ------ Retry Mechanism -----
        print("RETRYING....... !")
        os.system(f'python3 retry.py pylogs.log {appId}')
    else:
        print(f"Token Missing ! API response {responseJson}")
Here's a refactoring of your code to use the standard library's multiprocessing.pool.ThreadPool for simplicity.
Naturally I couldn't test it since I don't have your data, but the basic idea should work. I removed the logging and retry stuff, since I really couldn't understand why you'd need it (but feel free to add it back); this version retries each row if the problem appears to be transient.
import random
import sys
import time
from multiprocessing.pool import ThreadPool
import pandas as pd
import requests
sess = requests.Session()
# globals filled in by `main`
tenantName = None
access_token = None
def submit_user_create(row):
    headers = {"Content-Type": "application/json", "Authorization": "Bearer " + access_token}
    randomLength = [8, 9, 12, 13, 16]
    passwordLength = random.choice(randomLength)
    password = generateRandomPassword(passwordLength) # will be generated randomly - for debugging purpose
    post_request_body = {
        "accountEnabled": True,
        "displayName": row[5],
        "givenName": row[3],
        "surname": row[4],
        "mobilePhone": row[1],
        "mail": row[2],
        "passwordProfile": {"password": password, "forceChangePasswordNextSignIn": False},
        "state": "",
        "identities": [{"signInType": "emailAddress", "issuer": tenantName, "issuerAssignedId": row[2]}],
    }
    # if phone number exists then only add - since phone number needs to have length between 1 and 64, cannot leave empty
    if len(row[4]):
        post_request_body["identities"].append({"signInType": "phoneNumber", "issuer": tenantName, "issuerAssignedId": row[1]})
    return sess.post("https://graph.microsoft.com/v1.0/users", headers=headers, json=post_request_body)
def get_access_token(tenantName, applicationID, accessSecret):
    token_api_body = {
        "grant_type": "client_credentials",
        "scope": "https://graph.microsoft.com/.default",
        "client_Id": applicationID,
        "client_secret": accessSecret,
    }
    token_api = f"https://login.microsoftonline.com/{tenantName}/oauth2/v2.0/token"
    resp = sess.post(token_api, data=token_api_body)
    if resp.status_code != 200:
        raise RuntimeError(f"Token Missing ! API response {resp.content}")
    json = resp.json()
    print(f"Token API Success ! Expires in {json.get('expires_in')} seconds")
    return json["access_token"]
def process_row(row):
    while True:
        response = submit_user_create(row)
        status = response.status_code
        if status == 201: # success
            id = response.json().get("id")
            print(f"Success {id}")
            return True
        if status == 429: # throttling issues
            print(f"Throttled by server ! Sleeping for 150 seconds")
            time.sleep(150)
            continue
        if status == 401: # token expiry?
            print(f"Token Expired. Getting it back !")
            getRefreshToken() # TODO
            continue
        try:
            msg = response.json().get("error").get("message")
        except Exception as e:
            msg = f"Error {e}"
        print(f" {status} | {msg} {row[2]}")
        return False
def main():
    global tenantName, access_token
    # get file name and appid from command line arguments
    bigFilePath = sys.argv[1]
    appId = sys.argv[2]
    # setup credentials
    B2C_Tenant_Name = "tenantName"
    tenantName = f"{B2C_Tenant_Name}.onmicrosoft.com"
    accessSecret = "" # will be taken from command line in future revisions
    access_token = get_access_token(tenantName, appId, accessSecret)
    migrationData = pd.read_csv(bigFilePath)
    start_time = time.time()
    with ThreadPool(10) as pool:
        for i, result in enumerate(pool.imap_unordered(process_row, migrationData.itertuples()), 1):
            progress = i / len(migrationData) * 100
            print(f"{i} / {len(migrationData)} | {progress:.2f}% | {time.time() - start_time:.2f} seconds")
    print(f"Done! It took {time.time() - start_time}s to execute")
if __name__ == "__main__":
    main()
Unfair use of MS Graph
Because the server may throttle requests, usage of the MS Graph resource might be unfair between threads. I use "fair" in the resource-starvation sense.
elif(status == 429): #throttling issues
    print(f" Thread {threadID} | Throttled by server ! Sleeping for 150 seconds")
    errorDict = f'{row[1]}^{row[2]}^{row[3]}^{row[4]}^{row[5]}^{row[6]}^{row[7]}^{row[8]}^{row[9]}^{row[10]}{row[11]}{row[12]}{row[13]}^Throttled'
    time.sleep(150)
One thread making a million calls can get a disproportionate number of 429 responses, each followed by a penalty of 150 seconds. This sleep doesn't stop the other threads from making calls and achieving forward progress, though.
This would result in one thread lagging far behind the others and giving the appearance of being stuck.
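One possible way to make the penalty shared rather than per-thread is a small gate that every worker consults before each request. This is a sketch, not something proposed in the answers above: `ThrottleGate`, `wait` and `penalize` are hypothetical names, and the clock and sleep functions are injectable only so the behaviour can be demonstrated with a fake clock instead of real waiting.

```python
import threading
import time


class ThrottleGate:
    """Shared back-off: once any thread is throttled, all threads pause."""

    def __init__(self, clock=time.monotonic, sleep=time.sleep):
        self._lock = threading.Lock()
        self._resume_at = 0.0
        self._clock = clock
        self._sleep = sleep

    def penalize(self, seconds):
        """Record a throttling response; never shortens an existing pause."""
        with self._lock:
            self._resume_at = max(self._resume_at, self._clock() + seconds)

    def wait(self):
        """Block until the shared pause is over; returns the seconds waited."""
        with self._lock:
            remaining = self._resume_at - self._clock()
        if remaining > 0:
            self._sleep(remaining)
            return remaining
        return 0.0


class FakeClock:
    """Hypothetical clock whose sleep() advances time instantly."""

    def __init__(self):
        self.now = 0.0

    def time(self):
        return self.now

    def sleep(self, seconds):
        self.now += seconds


fake = FakeClock()
gate = ThrottleGate(clock=fake.time, sleep=fake.sleep)
first_wait = gate.wait()    # nothing has been throttled yet
gate.penalize(150)          # some worker received a 429
second_wait = gate.wait()   # any worker now waits out the shared penalty
third_wait = gate.wait()    # the pause has already elapsed
print(first_wait, second_wait, third_wait)
```

In a worker, `gate.wait()` would go before each POST and `gate.penalize(150)` in the 429 branch, so that a throttling response slows every thread equally instead of starving one of them.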

YAML not well formed

I'm trying to stop my RDS cluster using Lambda, and I'm using the Python code shown below to target the cluster's tags.
However, when I try to create a CloudFormation stack, AWS tells me that the YAML is not well formed.
For example, it tells me that Key = 'True' is not well formed.
If I remove that bit of code, it then tells me that the next bit of code, client = boto3.client('rds'), is also not well formed.
I've put this code into an online Python validator and it didn't report any issues.
Can anyone help with this? Thanks
Tag = 'AutoPower'
Key = 'True'
client = boto3.client('rds')
response = client.describe_db_cluster()
for resp in response['DBCluster']:
    db_cluster_arn = resp['DBClusterArn']
    response = client.list_tags_for_resource(ResourceName=db_cluster_arn)
    for tags in response['TagList']:
        if tags['Key'] == str(Key) and tags['Value'] == str(Tag):
            status = resp['DBClusterStatus']
            ClusterID = resp['DBClusterIdentifier']
            print(InstanceID)
            if status == 'available':
                print("shutting down %s " % ClusterID)
                client.stop_db_cluster(DBClusterIdentifier=ClusterID)
                # print ("Do something with it : %s" % db_instance_arn)
            elif status == 'stopped':
                print("starting up %s " % ClusterID)
                client.start_db_cluster(DBClusterIdentifier=ClusterID)
            else:
                print("The database is " + status + " status!")
You are most likely running into indentation problems between the CloudFormation YAML and the inline function code. Here is an example of CloudFormation YAML that would use your code inline to create a function:
Resources:
  YourFunction:
    Type: AWS::Lambda::Function
    Properties:
      Code:
        ZipFile: |
          import boto3
          Tag = 'AutoPower'
          Key = 'True'
          # The Handler property below expects a function with this name
          def lambda_handler(event, context):
              client = boto3.client('rds')
              # The boto3 call is describe_db_clusters(), which returns 'DBClusters'
              response = client.describe_db_clusters()
              for resp in response['DBClusters']:
                  db_cluster_arn = resp['DBClusterArn']
                  response = client.list_tags_for_resource(ResourceName=db_cluster_arn)
                  for tags in response['TagList']:
                      if tags['Key'] == str(Key) and tags['Value'] == str(Tag):
                          status = resp['DBClusterStatus']
                          ClusterID = resp['DBClusterIdentifier']
                          print(ClusterID)
                          if status == 'available':
                              print("shutting down %s " % ClusterID)
                              client.stop_db_cluster(DBClusterIdentifier=ClusterID)
                              # print ("Do something with it : %s" % db_instance_arn)
                          elif status == 'stopped':
                              print("starting up %s " % ClusterID)
                              client.start_db_cluster(DBClusterIdentifier=ClusterID)
                          else:
                              print("The database is " + status + " status!")
      Handler: index.lambda_handler
      Role: !Sub "arn:aws:iam::${AWS::AccountId}:role/YourRoleNameHere"
      Runtime: python3.9
      Description: This is an example of a function in CloudFormation
      FunctionName: Your_Function_Name
      MemorySize: 128
      Timeout: 180
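Because the inline code sits inside a YAML block scalar, every line of the Python source has to be indented deeper than the ZipFile key while keeping its own relative indentation. A minimal sketch of producing that indentation with the standard library follows; `render_zipfile_block` and its `key_indent` parameter are hypothetical names used only for illustration, and the handler source is a made-up example.

```python
import textwrap


def render_zipfile_block(source, key_indent=8):
    """Return a 'ZipFile: |' YAML fragment with the code indented under it."""
    # Normalise the source first so its own relative indentation is kept.
    code = textwrap.dedent(source).strip("\n")
    # Every code line must sit deeper than the ZipFile key itself.
    body = textwrap.indent(code, " " * (key_indent + 2))
    return " " * key_indent + "ZipFile: |\n" + body + "\n"


handler_source = """
def lambda_handler(event, context):
    if event:
        print("got an event")
    return "ok"
"""

fragment = render_zipfile_block(handler_source)
print(fragment)
```

The printed fragment can be pasted under the Code key of a template. The exact number of spaces is less important than the rule it demonstrates: no line of the Python code may start at or to the left of the column where ZipFile starts.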

How to mock functionality of boto3 module using pytest

I have a custom module written called sqs.py. The script will do the following:
Get a message from AWS SQS
Get the AWS S3 path to delete
Delete the path
Send a confirmation email to the user
I'm trying to write unit tests for this module that will verify the code will execute as expected and that it will raise exceptions when they do occur.
This means I will need to mock the responses from the Boto3 calls that I make. My problem is that the code first establishes the SQS client to obtain the message and then makes a second call to establish the S3 client. I'm not sure how to mock these 2 independent calls and fake a response so that I can test my script's functionality. Perhaps my approach is incorrect. In any case, any advice on how to do this properly is appreciated.
Here's what the code looks like:
import boto3
import json
import os
import pprint
import time
import asyncio
import logging
from send_email import send_email
queue_url = 'https://xxxx.queue.amazonaws.com/1234567890/queue'
def shutdown(message):
    """ Sends shutdown command to OS """
    os.system(f'shutdown +5 "{message}"')
def send_failure_email(email_config: dict, error_message: str):
    """ Sends email notification to user with error message attached. """
    recipient_name = email_config['recipient_name']
    email_config['subject'] = 'Subject: Restore Failed'
    email_config['message'] = f'Hello {recipient_name},\n\n' \
        + 'We regret that an error has occurred during the restore process. ' \
        + 'Please try again in a few minutes.\n\n' \
        + f'Error: {error_message}.\n\n'
    try:
        send_email(email_config)
    except RuntimeError as error_message:
        logging.error(f'ERROR: cannot send email to user. {error_message}')
async def restore_s3_objects(s3_client: object, p_bucket_name: str, p_prefix: str):
    """Attempts to restore objects specified by p_bucket_name and p_prefix.
    Returns True if restore took place, false otherwise.
    """
    is_truncated = True
    key_marker = None
    key = ''
    number_of_items_restored = 0
    has_restore_occured = False
    logging.info(f'performing restore for {p_bucket_name}/{p_prefix}')
    try:
        while is_truncated == True:
            if not key_marker:
                version_list = s3_client.list_object_versions(
                    Bucket = p_bucket_name,
                    Prefix = p_prefix)
            else:
                version_list = s3_client.list_object_versions(
                    Bucket = p_bucket_name,
                    Prefix = p_prefix,
                    KeyMarker = key_marker)
            if 'DeleteMarkers' in version_list:
                logging.info('found delete markers')
                delete_markers = version_list['DeleteMarkers']
                for d in delete_markers:
                    if d['IsLatest'] == True:
                        key = d['Key']
                        version_id = d['VersionId']
                        s3_client.delete_object(
                            Bucket = p_bucket_name,
                            Key = key,
                            VersionId = version_id
                        )
                        number_of_items_restored = number_of_items_restored + 1
            is_truncated = version_list['IsTruncated']
            logging.info(f'is_truncated: {is_truncated}')
            if 'NextKeyMarker' in version_list:
                key_marker = version_list['NextKeyMarker']
        if number_of_items_restored > 0:
            has_restore_occured = True
        return has_restore_occured
    except Exception as error_message:
        raise RuntimeError(error_message)
async def main():
    if 'AWS_ACCESS_KEY_ID' in os.environ \
            and 'AWS_SECRET_ACCESS_KEY' in os.environ \
            and os.environ['AWS_ACCESS_KEY_ID'] != '' \
            and os.environ['AWS_SECRET_ACCESS_KEY'] != '':
        sqs_client = boto3.client(
            'sqs',
            aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
            aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
            verify=False
        )
        s3_client = boto3.client(
            's3',
            aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
            aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'],
            verify=False
        )
    else:
        sqs_client = boto3.client(
            'sqs',
            verify=False,
        )
        s3_client = boto3.client(
            's3',
            verify=False,
        )
    received_message = sqs_client.receive_message(
        QueueUrl=queue_url,
        AttributeNames=['All'],
        VisibilityTimeout=10,
        WaitTimeSeconds=20, # Wait up to 20 seconds for a message to arrive
    )
    if 'Messages' in received_message \
            and len(received_message['Messages']) > 0:
        # NOTE: Initialize email configuration
        receipient_email = 'support@example.com'
        username = receipient_email.split('@')[0]
        fullname_length = len(username.split('.'))
        fullname = f"{username.split('.')[0]}" # Group name / First name only
        if (fullname_length == 2): # First name and last name available
            fullname = f"{username.split('.')[0]} {username.split('.')[1]}"
        fullname = fullname.title()
        email_config = {
            'destination': receipient_email,
            'recipient_name': fullname,
            'subject': 'Subject: Restore Complete',
            'message': ''
        }
        try:
            receipt_handle = received_message['Messages'][0]['ReceiptHandle']
        except Exception as error_message:
            logging.error(error_message)
            send_failure_email(email_config, error_message)
            shutdown(f'{error_message}')
        try:
            data = received_message['Messages'][0]['Body']
            data = json.loads(data)
            logging.info('A SQS message for a restore has been received.')
        except Exception as error_message:
            message = f'Unable to obtain and parse message body. {error_message}'
            logging.error(message)
            send_failure_email(email_config, message)
            shutdown(f'{error_message}')
        try:
            bucket = data['bucket']
            prefix = data['prefix']
        except Exception as error_message:
            message = f'Retrieving bucket name and prefix failed. {error_message}'
            logging.error(message)
            send_failure_email(email_config, message)
            shutdown(f'{error_message}')
        try:
            logging.info(f'Initiating restore for path: {bucket}/{prefix}')
            restore_was_performed = await asyncio.create_task(restore_s3_objects(s3_client, bucket, prefix))
            if restore_was_performed is True:
                email_config['message'] = f'Hello {fullname},\n\n' \
                    + f'The files in the path \'{bucket}/{prefix}\' have been restored. '
                send_email(email_config)
                logging.info('Restore complete. Shutting down.')
            else:
                logging.info('Path does not require restore. Shutting down.')
            shutdown(f'shutdown +5 "Restore successful! System will shutdown in 5 mins"')
        except Exception as error_message:
            message = f'File restoration failed. {error_message}'
            logging.error(message)
            send_failure_email(email_config, message)
            shutdown(f'{error_message}')
        try:
            sqs_client.delete_message(
                QueueUrl=queue_url,
                ReceiptHandle=receipt_handle,
            )
        except Exception as error_message:
            message = f'Deleting restore session from SQS failed. {error_message}'
            logging.error(message)
            send_failure_email(email_config, message)
            shutdown(f'{error_message}')
if __name__ == '__main__':
    logging.basicConfig(filename='restore.log',level=logging.INFO)
    loop = asyncio.get_event_loop()
    loop.run_until_complete(main())
    loop.close()
The only way I was able to mock Boto3 is rebuilding a small class that represents the actual method structure. This is because Boto3 uses dynamic methods and all the resource level methods are created at runtime.
This might not be industry standard, but most of the methods I found on the internet didn't work for me, whereas this worked pretty well and requires minimal effort (compared to some of the solutions I found).
import asyncio
import os
import unittest
from unittest import mock
import boto3
from sqs import restore_s3_objects
class MockClient:
    def __init__(self, region_name, aws_access_key_id, aws_secret_access_key):
        self.region_name = region_name
        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.MockS3 = MockS3()
    def client(self, service_name, **kwargs):
        return self.MockS3
class MockS3:
    def __init__(self):
        self.response = None # Set your mock data from S3 here
    def list_object_versions(self, **kwargs):
        return self.response
class S3TestCase(unittest.TestCase):
    def test_restore_s3_objects(self):
        # Given
        bucket = "testBucket" # Set this to something somewhat realistic
        prefix = "some/prefix" # Set this to something somewhat realistic
        env_vars = mock.patch.dict(os.environ, {"AWS_ACCESS_KEY_ID": "abc",
                                                "AWS_SECRET_ACCESS_KEY": "def"})
        env_vars.start()
        # Initialising the Session can be tricky since it has to be patched in
        # the module/file that creates the session in the actual code rather than
        # where the Session code lives. In this case you might have to import from
        # main rather than boto3.
        boto3.session.Session = mock.Mock(side_effect=[
            MockClient(region_name='eu-west-1',
                       aws_access_key_id=os.environ['AWS_ACCESS_KEY_ID'],
                       aws_secret_access_key=os.environ['AWS_SECRET_ACCESS_KEY'])])
        s3_client = boto3.client('s3', verify=False)
        # When
        # restore_s3_objects is a coroutine, so it has to be run to completion
        has_restore_occured = asyncio.run(restore_s3_objects(s3_client, bucket, prefix))
        # Then
        self.assertEqual(has_restore_occured, False) # your expected result set
        env_vars.stop()
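An alternative that avoids hand-written mock classes is unittest.mock.MagicMock. Since restore_s3_objects takes the client as a parameter, a fake can be passed in directly, and a small dispatch function can stand in for boto3.client when the SQS and S3 clients both need faking. The sketch below rests on several assumptions: it uses a trimmed, single-page stand-in for the question's coroutine so that it runs on its own, and `make_fake_clients` and `fake_client` are made-up names.

```python
import asyncio
from unittest import mock


async def restore_s3_objects(s3_client, bucket, prefix):
    """Trimmed stand-in for the question's coroutine (one page, no pagination)."""
    versions = s3_client.list_object_versions(Bucket=bucket, Prefix=prefix)
    restored = 0
    for marker in versions.get("DeleteMarkers", []):
        if marker["IsLatest"]:
            s3_client.delete_object(
                Bucket=bucket, Key=marker["Key"], VersionId=marker["VersionId"]
            )
            restored += 1
    return restored > 0


def make_fake_clients():
    """Build one MagicMock per service plus a boto3.client-like dispatcher."""
    s3 = mock.MagicMock(name="s3")
    s3.list_object_versions.return_value = {
        "IsTruncated": False,
        "DeleteMarkers": [
            {"Key": "some/prefix/a.txt", "VersionId": "v1", "IsLatest": True},
            {"Key": "some/prefix/b.txt", "VersionId": "v0", "IsLatest": False},
        ],
    }
    sqs = mock.MagicMock(name="sqs")
    sqs.receive_message.return_value = {
        "Messages": [{
            "ReceiptHandle": "rh-1",
            "Body": '{"bucket": "testBucket", "prefix": "some/prefix"}',
        }]
    }
    clients = {"s3": s3, "sqs": sqs}

    def fake_client(service_name, **kwargs):
        # Same call shape as boto3.client('s3', verify=False)
        return clients[service_name]

    return fake_client, s3, sqs


fake_client, s3, sqs = make_fake_clients()
restored = asyncio.run(
    restore_s3_objects(fake_client("s3", verify=False), "testBucket", "some/prefix")
)
print(restored, s3.delete_object.call_count)
```

In a real test, `fake_client` could be supplied as the `side_effect` of `mock.patch("sqs.boto3.client")`, assuming the module under test is importable as `sqs`, so that `main()` receives the matching fake for each service name it asks for.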

rq.job.Job returning a 'sh: 1: mv: not found' error

I currently have a master Python script which launches 6 jobs on remote hosts and polls whether the jobs are done over a long period (days, usually). However, in my code below, the first element in the self.job_result list is always 'sh: 1: mv: not found'. The 6 job values are always in that list as well (i.e. there are 7 elements in the list, and there should only be 6). It appears that rq.job.Job is returning this value; any idea why?
hosts = HOSTS.keys()
job_ids = []
for host in hosts:
    r = requests.get(HOSTS[host] + 'launch_jobs', auth=('admin', 'secret'))
    job_ids.append(r.text)
host_job_dict = dict(zip(hosts, job_ids))
print "HOST_JOB_DICT: %s " % host_job_dict
launch_time = datetime.datetime.now()
self.job_result = []
complete = False
status = [False]*len(hosts)
host_job_keys = host_job_dict.keys()
while not complete:
    check_time = datetime.datetime.now()
    time_diff = check_time - launch_time
    if time_diff.seconds > JOB_TIMEOUT:
        sys.exit('Job polling has lasted 10 days, something is wrong')
    print "HOST_JOB_KEYS %s " % host_job_keys
    for idx, key in enumerate(host_job_keys):
        if not status[idx]:
            host = HOSTS[key]
            j_id = host_job_dict[key]
            req = requests.get(host + 'check_job/' + j_id, auth=('admin', 'secret'))
            if req.status_code == 202:
                continue
            elif req.status_code == 200:
                self.job_result.append(req.json()['results'].encode('ascii').split())
                status[idx] = True
    complete = all(status)
    time.sleep(1)
And on the server side of things...
@app.route("/check_job/<job_key>", methods=['GET'])
@requires_auth
def check_job(job_key):
    job = Job.fetch(job_key, connection=conn)
    if job.is_finished:
        data = job.return_value
        json_data = jsonify({"results": data})
        # return Response(response=json_data, status=200, mimetype="application/json")
        return json_data
    elif job.status == 'failed':
        return "Failed", 202
    else:
        return "Not yet", 202
This turned out to be an extremely convoluted issue where mv and other commands in /bin weren't being recognized. To get around this, we were explicit and used /bin/mv instead. We believe this issue cropped up as a result of a complication from a systemctl instantiation.
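The workaround of calling /bin/mv explicitly can be generalised a little by resolving the absolute path once, before the job shells out. The following is a minimal sketch under stated assumptions: `resolve_command` and its `fallback_dirs` parameter are hypothetical names, and the fallback directories assume a typical Linux layout.

```python
import os
import shutil


def resolve_command(name, path=None, fallback_dirs=("/bin", "/usr/bin")):
    """Return an absolute path for `name`, even if PATH does not contain it."""
    # shutil.which searches the given path string, or os.environ["PATH"] if None.
    found = shutil.which(name, path=path)
    if found:
        return os.path.abspath(found)
    # PATH did not help, so look in the usual system directories directly.
    for directory in fallback_dirs:
        candidate = os.path.join(directory, name)
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return candidate
    raise FileNotFoundError(f"{name} not found on PATH or in {fallback_dirs}")


# Simulate a job whose environment has an empty PATH.
mv_path = resolve_command("mv", path="")
print(mv_path)
```

The job could then run something like `subprocess.run([mv_path, src, dst])` instead of passing a bare `mv` to a shell, which would keep it working regardless of the PATH the worker process was started with.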
