Using Google Download Operator in Airflow with multiple files - python

I have a BigQuery table that I need to pull down and use to populate an MSSQL table. Since I can't find a BigQueryToMSSQL operator, I'm doing this by hand.
I've been able to export the table to a series of <>_001.txt, <>_002.txt, etc. files and store them in GCS, but now I need to get them down onto the Airflow server.
I'm attempting to use the GoogleCloudStorageDownloadOperator, but it seems to have an issue I cannot fix.
# NOTE: GoogleCloudStorageDownloadOperator downloads exactly one object and does
# not expand '*' wildcards, so `object` below is looked up literally as an object
# named 'TAX_ASSESSOR_LIVE_*.txt' - which explains the 404 returned by GCS.
# `filename` is presumably also treated as one literal local path rather than a
# pattern - confirm before relying on it.
Export_to_Local = GoogleCloudStorageDownloadOperator(
    task_id='Export_GCS_to_Airflow_Staging',
    bucket='offrs',
    object='TAX_ASSESSOR_LIVE_*.txt',
    filename=Variable.get("temp_directory") + "TAL/*",
    google_cloud_storage_conn_id='GCP_Mother_Staging',
    dag=dag
)
The above code results in this error:
google.resumable_media.common.InvalidResponse: ('Request failed with status code', 404, 'Expected one of', <HTTPStatus.OK: 200>, <HTTPStatus.PARTIAL_CONTENT: 206>)
Am I missing something? I don't know what the problem is.
Thanks

GoogleCloudStorageDownloadOperator does not support wildcards, unfortunately.
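The quickest option would be to use the gsutil command in a BashOperator, if your VM is already authorized to access that bucket, e.g. `gsutil -m cp 'gs://offrs/TAX_ASSESSOR_LIVE_*.txt' /path/to/temp_directory/TAL/`.
The other option is to use the following custom operator: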
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from airflow.exceptions import AirflowException
WILDCARD = '*'


class CustomGcsDownloadOperator(BaseOperator):
    """
    Download one GCS object, or every object matching a single-wildcard
    pattern, into a local folder.

    :param source_bucket: GCS bucket to read from. (templated)
    :param source_object: object name; may contain at most one ``*``. (templated)
    :param destination_folder: local folder the files are written into. (templated)
    :param destination_object: local file name relative to ``destination_folder``.
        For a wildcard ``source_object`` it replaces the part before the ``*``.
        Defaults to the source object name. (templated)
    :param google_cloud_storage_conn_id: Airflow connection id passed to the hook.
    :param delegate_to: passed through unchanged to GoogleCloudStorageHook.
    :param last_modified_time: if set, only objects updated after this time
        are downloaded.
    :raises AirflowException: if ``source_object`` contains more than one ``*``.
    """
    template_fields = ('source_bucket', 'source_object', 'destination_folder',
                       'destination_object',)
    ui_color = '#f0eee4'

    @apply_defaults
    def __init__(self,
                 source_bucket,
                 source_object,
                 destination_folder,
                 destination_object=None,
                 google_cloud_storage_conn_id='google_cloud_default',
                 delegate_to=None,
                 last_modified_time=None,
                 *args,
                 **kwargs):
        super(CustomGcsDownloadOperator, self).__init__(*args, **kwargs)
        self.source_bucket = source_bucket
        self.source_object = source_object
        self.destination_folder = destination_folder
        self.destination_object = destination_object
        self.google_cloud_storage_conn_id = google_cloud_storage_conn_id
        self.delegate_to = delegate_to
        self.last_modified_time = last_modified_time

    def execute(self, context):
        hook = GoogleCloudStorageHook(
            google_cloud_storage_conn_id=self.google_cloud_storage_conn_id,
            delegate_to=self.delegate_to
        )

        if WILDCARD not in self.source_object:
            self._download_single_object(hook=hook, source_object=self.source_object,
                                         destination_object=self.destination_object)
            return

        total_wildcards = self.source_object.count(WILDCARD)
        if total_wildcards > 1:
            error_msg = "Only one wildcard '*' is allowed in source_object parameter. " \
                        "Found {} in {}.".format(total_wildcards, self.source_object)
            raise AirflowException(error_msg)

        # Text before the '*' is the listing prefix; text after it is passed as
        # the hook's `delimiter`, which the Airflow GCS hook uses to filter by
        # suffix (e.g. '.txt').
        prefix, delimiter = self.source_object.split(WILDCARD, 1)
        objects = hook.list(self.source_bucket, prefix=prefix, delimiter=delimiter)
        for source_object in objects or []:
            if self.destination_object is None:
                destination_object = source_object
            else:
                destination_object = source_object.replace(prefix,
                                                           self.destination_object, 1)
            self._download_single_object(hook=hook, source_object=source_object,
                                         destination_object=destination_object)

    def _download_single_object(self, hook, source_object, destination_object):
        """
        Download ``source_object`` to ``destination_folder/destination_object``.

        The object is skipped when ``last_modified_time`` is set and the object
        has not been updated after it. Missing local sub-folders are created.
        """
        import os

        if destination_object is None:
            # The hook only writes to disk when given a file name, so default to
            # the object's own name instead of passing None.
            destination_object = source_object

        if self.last_modified_time is not None:
            if not hook.is_updated_after(self.source_bucket,
                                         source_object,
                                         self.last_modified_time):
                return
            self.log.debug("Object has been modified after %s ", self.last_modified_time)

        self.log.info('Executing copy of gs://%s/%s to file://%s/%s',
                      self.source_bucket, source_object,
                      self.destination_folder, destination_object)

        # Write inside destination_folder instead of relative to the worker's
        # current directory.
        local_path = os.path.join(self.destination_folder, destination_object)
        local_dir = os.path.dirname(local_path)
        if local_dir and not os.path.isdir(local_dir):
            # Object names may contain '/', so create any missing sub-folders.
            os.makedirs(local_dir)
        hook.download(self.source_bucket, source_object, local_path)

Related

How do I set the python docusign_esign ApiClient to use a proxy?

I am using the following examples from the docusign site.
I have a set of Python scripts that work well on my PC.
I have to move the code to a server behind a proxy.
I could not find any example or setting to configure a proxy.
I tried setting it in the underlying urllib3 code, but it is overwritten each time the app creates a new ApiClient() instance.
How do I set the python docusign_esign ApiClient to use a proxy?
Below is the portion of the code.
from docusign_esign import ApiClient
from docusign_esign import EnvelopesApi
from jwt_helper import get_jwt_token, get_private_key
# this one has all the connection parameters
from jwt_config import DS_JWT
import urllib3
# Outbound proxy for the DocuSign API traffic. urllib3's documented way to
# authenticate to a proxy is a Proxy-Authorization header built with
# make_headers(proxy_basic_auth=...). The previous URL also used '#' where '@'
# was meant, which turned the host part into a URL fragment.
proxy = urllib3.ProxyManager(
    'http://<proxy_server>:3128/',
    proxy_headers=urllib3.make_headers(proxy_basic_auth='<id>:<pwd>'),
    maxsize=10,
)
# used by docusign to decide what you have access to
SCOPES = ["signature", "impersonation"]
# Call the envelope status change method to list recently changed envelopes
def worker(args, days=120):
    """
    List the account's envelopes whose status changed in the last ``days`` days.

    :param args: dict with "base_path", "access_token" and "account_id"
        (the shape built by get_args).
    :param days: look-back window in days; 120 matches the previously
        hard-coded value.
    :return: (results, envelope_api) - the list_status_changes response and the
        EnvelopesApi that produced it, so callers can fetch documents with the
        same (proxied) client.
    """
    import datetime

    api_client = ApiClient()
    # Each ApiClient() builds its own REST client, so the proxy must be
    # installed again on this new instance. ProxyManager subclasses urllib3's
    # PoolManager and can replace the client's pool manager.
    api_client.rest_client.pool_manager = proxy
    api_client.host = args['base_path']
    api_client.set_default_header("Authorization", "Bearer " + args['access_token'])
    envelope_api = EnvelopesApi(api_client)
    # list_status_changes requires at least a from_date OR a set of envelope ids;
    # here we filter with a from_date in ISO 8601 format.
    from_date = (datetime.datetime.utcnow() - datetime.timedelta(days=days)).isoformat()
    results = envelope_api.list_status_changes(args['account_id'], from_date=from_date)
    return results, envelope_api
# Request a JWT user token and resolve which account / REST endpoint to use
def get_token(private_key, api_client):
    """
    Obtain an access token via get_jwt_token and look up the first account.

    :param private_key: private key passed straight to get_jwt_token.
    :param api_client: ApiClient used for the get_user_info call.
    :return: dict with "access_token", "api_account_id" and "base_path"
        (the first account's base_uri followed by "/restapi").
    """
    token_response = get_jwt_token(private_key, SCOPES, DS_JWT["authorization_server"],
                                   DS_JWT["ds_client_id"], DS_JWT["ds_impersonated_user_id"])
    access_token = token_response.access_token
    first_account = api_client.get_user_info(access_token).get_accounts()[0]
    return {
        "access_token": access_token,
        "api_account_id": first_account.account_id,
        "base_path": first_account.base_uri + "/restapi",
    }
# Package the token details into the dict consumed by worker()
def get_args(api_account_id, access_token, base_path):
    """Return a dict with keys "account_id", "base_path" and "access_token"."""
    return dict(
        account_id=api_account_id,
        base_path=base_path,
        access_token=access_token,
    )
# start the actual code here: create and then set up the API client
import os
import zipfile

api_client = ApiClient()
api_client.set_base_path(DS_JWT["authorization_server"])
api_client.set_oauth_host_name(DS_JWT["authorization_server"])
# Route this client's requests through the proxy. Only a urllib3 ProxyManager
# sends requests via a proxy; assigning a `.proxy` attribute on the existing
# PoolManager does not. ProxyManager subclasses PoolManager, so it can replace it.
api_client.rest_client.pool_manager = proxy
private_key = get_private_key(DS_JWT["private_key_file"]).encode("ascii").decode("utf-8")
jwt_values = get_token(private_key, api_client)
args = get_args(jwt_values["api_account_id"], jwt_values["access_token"], jwt_values["base_path"])
account_id = args["account_id"]
# return the envelope list and the EnvelopesApi object used to get it
results, envelope_api = worker(args)
print("We found " + str(results.result_set_size) + " sets of files")
# NOTE(review): presumably `envelopes` is None when nothing matched - guard it.
for envelope in results.envelopes or []:
    envelope_id = envelope.envelope_id
    print("Extracting " + envelope_id)
    # The SDK always stores the received file as a temp file; its path cannot be chosen.
    temp_file = envelope_api.get_document(account_id=account_id, document_id="archive",
                                          envelope_id=envelope_id)
    if temp_file:
        print("File is here " + temp_file)
        # NOTE(review): `extract_dir` is not defined in the visible code; it must
        # be set before this loop runs.
        with zipfile.ZipFile(temp_file, 'r') as zip_ref:
            zip_ref.extractall(os.path.join(extract_dir, envelope_id))
        # The archive is closed by the `with` block before it is deleted.
        print("Done extracting " + envelope_id + " deleting zip file")
        os.remove(temp_file)
        print("Deleted file here " + temp_file)
    else:
        print("Failed to get data for " + envelope_id)

openai.error.InvalidRequestError: Engine not found

Tried running the OpenAI example "Explain code",
but it shows this error:
InvalidRequestError: Engine not found
# NOTE: code-davinci-002 was a private-beta Codex engine; the API answers
# "Engine not found" when the key has no access to it.
response = openai.Completion.create(
    engine="code-davinci-002",
    prompt="class Log:\n def __init__(self, path):\n dirname = os.path.dirname(path)\n os.makedirs(dirname, exist_ok=True)\n f = open(path, \"a+\")\n\n # Check that the file is newline-terminated\n size = os.path.getsize(path)\n if size > 0:\n f.seek(size - 1)\n end = f.read(1)\n if end != \"\\n\":\n f.write(\"\\n\")\n self.f = f\n self.path = path\n\n def log(self, event):\n event[\"_event_id\"] = str(uuid.uuid4())\n json.dump(event, self.f)\n self.f.write(\"\\n\")\n\n def state(self):\n state = {\"complete\": set(), \"last\": None}\n for line in open(self.path):\n event = json.loads(line)\n if event[\"type\"] == \"submit\" and event[\"success\"]:\n state[\"complete\"].add(event[\"id\"])\n state[\"last\"] = event\n return state\n\n\"\"\"\nHere's what the above class is doing:\n1.",
    temperature=0,
    max_tokens=64,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
    stop=["\"\"\""]
)
I've been trying to access the engine named code-davinci-002, which is a private beta engine, so it cannot be used without being granted access. It seems only the GPT-3 models are publicly available. We need to join the OpenAI Codex Private Beta Waitlist in order to access Codex models through the API.
Please note that your code is not very readable.
However, from the given error, I think it has to do with the missing colon : in the engine name.
Change this line from:
engine="code-davinci-002",
to
engine="code-davinci:002",
If you are using a finetuned model instead of an engine, you'd want to use model= instead of engine=.
response = openai.Completion.create(
model="<finetuned model>",
prompt=

TypeError: individual_success_handle() missing 4 required positional arguments: 'request_type', 'name', 'response_time', and 'response_length'

I have been trying to integrate Locust with InfluxDB, and while implementing an EventHook I am facing the error "TypeError: individual_success_handle() missing 4 required positional arguments: 'request_type', 'name', 'response_time', and 'response_length'" after running my locustfile.py. The code is attached below:
from locust import HttpUser, task, TaskSet, events
import json
import datetime
import pytz
from influxdb import InfluxDBClient
import socket
hostname = socket.gethostname()
client = InfluxDBClient(host="localhost", port="8086")
client.switch_database('DemoDB')
def individual_success_handle(request_type, name, response_time, response_length, **kwargs):
    """
    Write one successful Locust request to the InfluxDB measurement "ResponseTable".

    Extra keyword arguments sent by Locust are accepted and ignored. Values are
    stored as strings, exactly as before.
    """
    # Build the point as a dict rather than %-formatting a JSON string: a quote
    # or backslash in the request name used to produce invalid JSON.
    # NOTE(review): numeric fields would let InfluxDB aggregate response times,
    # but switching type would likely conflict with the string fields already
    # written to this measurement.
    point = {
        "measurement": "ResponseTable",
        "tags": {
            "hostname": hostname,
            "requestName": str(name),
            "requestType": str(request_type),
            "status": "success",
        },
        "time": datetime.datetime.now(tz=pytz.UTC),
        "fields": {
            "responseTime": str(response_time),
            "responseLength": str(response_length),
        },
    }
    client.write_points([point], time_precision='ms')
def individual_fail_handle(request_type, name, response_time, response_length, exception, **kwargs):
    """
    Write one failed Locust request to the InfluxDB measurement "ResponseTable".

    ``exception`` is stored as its str() in the "exception" tag. Extra keyword
    arguments sent by Locust are accepted and ignored.
    """
    # Build the point as a dict rather than %-formatting a JSON string:
    # exception messages frequently contain quotes, which broke json.loads.
    point = {
        "measurement": "ResponseTable",
        "tags": {
            "hostname": hostname,
            "requestName": str(name),
            "requestType": str(request_type),
            "exception": str(exception),
            "status": "fail",
        },
        "time": datetime.datetime.now(tz=pytz.UTC),
        "fields": {
            "responseTime": str(response_time),
            "responseLength": str(response_length),
        },
    }
    client.write_points([point], time_precision='ms')
def _dispatch_request_event(request_type, name, response_time, response_length,
                            exception=None, **kwargs):
    """Route Locust's combined `request` event to the success or failure handler."""
    if exception is None:
        individual_success_handle(request_type, name, response_time, response_length, **kwargs)
    else:
        individual_fail_handle(request_type, name, response_time, response_length,
                               exception, **kwargs)


# Register the handler functions themselves - calling them here (with "()")
# runs them immediately with no arguments, which raised the TypeError.
if hasattr(events, "request"):
    # Newer Locust releases fire a single `request` event for both outcomes.
    events.request.add_listener(_dispatch_request_event)
else:
    events.request_success.add_listener(individual_success_handle)
    events.request_failure.add_listener(individual_fail_handle)
class UserBehavior(TaskSet):
    """Request /help and /pilot, picking /pilot twice as often."""

    @task(1)
    def profile_help(self):
        """GET /help (weight 1)."""
        self.client.get("/help")

    # Distinct method names are required: a second `def profile` replaced the
    # first, so only one of the two tasks ever existed.
    @task(2)
    def profile(self):
        """GET /pilot (weight 2)."""
        self.client.get("/pilot")
class WebsiteUser(HttpUser):
    """Simulated user running UserBehavior with a 5-9 second pause between tasks."""
    tasks = [UserBehavior]
    # Pause bounds in milliseconds, kept for compatibility. HttpUser-based Locust
    # schedules pauses through `wait_time`, so the method below converts them.
    min_wait = 5000
    max_wait = 9000

    def wait_time(self):
        """Return a uniformly random pause in seconds between min_wait and max_wait."""
        import random
        return random.uniform(self.min_wait, self.max_wait) / 1000.0
As I'm trying to push the values dynamically, I'm not really sure what I should pass as parameters.
Remove the parentheses here:
events.request_success += individual_success_handle()
events.request_failure += individual_fail_handle()
You end up calling the method right away (which crashes, because you don't supply any parameters; and even if it had worked, you would have added its return value to the success/failure hook instead of the method itself).
You should be able to just copy paste from
https://github.com/SvenskaSpel/locust-plugins/blob/master/locust_plugins/listeners.py
Just change so that you write to Influx instead of Postgres.
You'll want to buffer your events and send them in batches, as Influx (and even your load gen) might not be happy with thousands of requests/s. That's also implemented in locust-plugins.

How can we list all the parameters in the aws parameter store using Boto3? There is no ssm.list_parameters in boto3 documentation?

SSM — Boto 3 Docs 1.9.64 documentation
get_parameters doesn't list all parameters?
For those who want to just copy-paste the code:
import boto3
ssm = boto3.client('ssm')
# Returns only the first page of results; follow NextToken (or use a paginator)
# to collect every parameter.
parameters = ssm.describe_parameters()['Parameters']
Beware of the limit of max 50 parameters!
This code will get all parameters by repeatedly fetching pages until there are no more (at most 50 are returned per call):
import boto3
def get_resources_from(ssm_details):
    """
    Split one describe_parameters response into (parameters, next_token).

    :param ssm_details: a describe_parameters response dict.
    :return: a new list with the page's parameter dicts, and the NextToken
        (None when the response has none, i.e. there are no more pages).
    """
    resources = list(ssm_details['Parameters'])
    next_token = ssm_details.get('NextToken', None)
    return resources, next_token
def main(region_name='us-east-1'):
    """Print the metadata of every SSM parameter in ``region_name``, following all pages."""
    ssm = boto3.client('ssm', region_name=region_name)
    # Only send NextToken once the service has returned one, instead of
    # starting with a placeholder ' ' token.
    request = {'MaxResults': 50}
    resources = []
    while True:
        ssm_details = ssm.describe_parameters(**request)
        current_batch, next_token = get_resources_from(ssm_details)
        resources += current_batch
        if next_token is None:
            break
        request['NextToken'] = next_token
    print(resources)
    print('done')
You can use the get_paginator API; see the example below. In my use case I had to get all the values from the SSM Parameter Store and compare them with a string.
import boto3
import sys
# Substring to look for in parameter values, taken from the first CLI argument.
LBURL = sys.argv[1].strip()
client = boto3.client('ssm')
# Walk every describe_parameters page and merge them into one result dict.
p = client.get_paginator('describe_parameters')
paginator = p.paginate().build_full_result()
# Despite its name, `page` is one parameter's metadata dict, not a page.
for page in paginator['Parameters']:
    # One GetParameter call per parameter.
    # NOTE(review): without WithDecryption=True, SecureString values presumably
    # come back encrypted and will never match LBURL - confirm if needed.
    response = client.get_parameter(Name=page['Name'])
    value = response['Parameter']['Value']
    if LBURL in value:
        print("Name is: " + page['Name'] + " and Value is: " + value)
One of the responses from above/below(?) (by Val Lapidas) inspired me to expand it to this (as his solution doesn't get the SSM parameter value, and some other, additional details).
The downside here is that the AWS function client.get_parameters() only allows 10 names per call.
There's one referenced function call in this code (to_pdatetime(...)) that I have omitted - it just takes the datetime value and makes sure it is a "naive" datetime. This is because I am ultimately dumping this data to an Excel file using pandas, which doesn't deal well with timezones.
from typing import List, Tuple
from boto3 import session
from mypy_boto3_ssm import SSMClient
def ssm_params(aws_session: session.Session = None) -> List[dict]:
    """
    Return a detailed list of all the SSM parameters, including their values.

    describe_parameters supplies the metadata but not the values, so each page
    of descriptions is passed to get_parameters to fetch the values. Pages are
    requested 10 at a time because get_parameters accepts at most 10 names.

    :param aws_session: boto3 session used to create the SSM client (required).
    :raises ValueError: if ``aws_session`` is None.
    :return: one dict per parameter with Name, LastModifiedDate,
        LastModifiedUser, Version, Tier, Policies, ARN, DataType, Type and Value.
    """

    def get_parameter_values(ssm_client: SSMClient, ssm_details: dict) -> Tuple[list, str]:
        """
        Fetch values and extra attributes for the parameters in one
        describe_parameters page.

        :return: (resources, next_token); next_token is None on the last page.
        """
        ssm_param_details = ssm_details['Parameters']
        next_token = ssm_details.get('NextToken', None)
        # get_parameters requires at least one name, while describe_parameters
        # may return an empty page that still carries a NextToken.
        if not ssm_param_details:
            return [], next_token

        # Index the descriptions by name once instead of scanning per result.
        details_by_name = {details.get('Name', None): details for details in ssm_param_details}
        ssm_params_with_values = ssm_client.get_parameters(
            Names=[details['Name'] for details in ssm_param_details],
            WithDecryption=True)

        resources = []
        for result in ssm_params_with_values['Parameters']:
            # The description carries fields get_parameters does not return.
            param_details = details_by_name.get(result['Name'], {})
            # Missing or empty policy lists are reported as None.
            param_policy = param_details.get('Policies', None) or None
            resources.append({
                'Name': result['Name'],
                'LastModifiedDate': to_pdatetime(result['LastModifiedDate']),
                'LastModifiedUser': param_details.get('LastModifiedUser', None),
                'Version': result['Version'],
                'Tier': param_details.get('Tier', None),
                'Policies': param_policy,
                'ARN': result['ARN'],
                'DataType': result.get('DataType', None),
                'Type': result.get('Type', None),
                'Value': result.get('Value', None)
            })
        return resources, next_token

    if aws_session is None:
        raise ValueError('No session.')

    aws_ssm_client = aws_session.client('ssm')
    ssm_resources = []
    # Only send NextToken once the service has returned one.
    request = {'MaxResults': 10}
    while True:
        ssm_descriptions = aws_ssm_client.describe_parameters(**request)
        current_batch, next_token = get_parameter_values(ssm_client=aws_ssm_client,
                                                         ssm_details=ssm_descriptions)
        ssm_resources += current_batch
        if next_token is None:
            break
        request['NextToken'] = next_token

    print(f'SSM Parameters: {len(ssm_resources)}')
    return ssm_resources
Tags: python, aws, boto3, amazon-web-services
There's no ListParameters, only DescribeParameters, which lists all the parameters; you can optionally set filters.
Boto3 Docs Link:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm.html#SSM.Client.describe_parameters
AWS API Documentation Link:
https://docs.aws.amazon.com/systems-manager/latest/APIReference/API_DescribeParameters.html
You can use get_parameters() and get_parameters_by_path().
Use paginators.
# describe_parameters is pageable: iterate paginator.paginate() (or call
# .build_full_result()) to see the 'Parameters' of every page.
paginator = client.get_paginator('describe_parameters')
More information here.

google dataflow read from spanner

I am trying to read a table from a Google spanner database, and write it to a text file to do a backup, using google dataflow with the python sdk.
I have written the following script:
from __future__ import absolute_import
import argparse
import itertools
import logging
import re
import time
import datetime as dt
import logging
import apache_beam as beam
from apache_beam.io import iobase
from apache_beam.io import WriteToText
from apache_beam.io.range_trackers import OffsetRangeTracker, UnsplittableRangeTracker
from apache_beam.metrics import Metrics
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions, SetupOptions
from apache_beam.options.pipeline_options import GoogleCloudOptions
from google.cloud.spanner.client import Client
from google.cloud.spanner.keyset import KeySet
# GCS bucket for the job; OUTPUT is the prefix of the backup text files.
BUCKET_URL = 'gs://my_bucket'
OUTPUT = '%s/output/' % BUCKET_URL
# Spanner coordinates; PROJECT_ID is also used as the Dataflow project in run().
PROJECT_ID = 'my_project'
INSTANCE_ID = 'my_instance'
DATABASE_ID = 'my_db'
JOB_NAME = 'spanner-backup'  # Dataflow job name
TABLE = 'my_table'  # the single table read by SpannerSource
class SpannerSource(iobase.BoundedSource):
    """
    Unsplittable Beam source yielding every row of TABLE as a
    ``{column_name: value}`` dict.

    The constructor only records the project/instance/database ids; the Spanner
    client is created inside ``read``.
    """

    def __init__(self):
        logging.info('Enter __init__')
        self.spannerOptions = {
            "id": PROJECT_ID,
            "instance": INSTANCE_ID,
            "database": DATABASE_ID
        }
        self.SpannerClient = Client

    def estimate_size(self):
        """Return a placeholder size estimate of 1; the table size is not queried."""
        logging.info('Enter estimate_size')
        return 1

    def get_range_tracker(self, start_position=None, stop_position=None):
        """Return an unsplittable offset tracker, defaulting to the range [0, infinity)."""
        logging.info('Enter get_range_tracker')
        if start_position is None:
            start_position = 0
        if stop_position is None:
            stop_position = OffsetRangeTracker.OFFSET_INFINITY
        range_tracker = OffsetRangeTracker(start_position, stop_position)
        return UnsplittableRangeTracker(range_tracker)

    def read(self, range_tracker):
        """
        Yield the rows of TABLE as dicts keyed by column name.

        The column names are looked up in information_schema first.
        ``range_tracker`` is not consulted: the whole table is read in one pass.
        """
        logging.info('Enter read')
        spanner_client = self.SpannerClient(self.spannerOptions["id"])
        instance = spanner_client.instance(self.spannerOptions["instance"])
        database = instance.database(self.spannerOptions["database"])
        table_fields = database.execute_sql("SELECT t.column_name FROM information_schema.columns AS t WHERE t.table_name = '%s'" % TABLE)
        table_fields.consume_all()
        self.columns = [x[0] for x in table_fields]
        keyset = KeySet(all_=True)
        results = database.read(table=TABLE, columns=self.columns, keyset=keyset)
        results.consume_all()
        for row in results:
            yield {column: row[i] for i, column in enumerate(self.columns)}

    def split(self, desired_bundle_size=None, start_position=None, stop_position=None):
        """
        Return the whole source as a single bundle; ``desired_bundle_size`` is ignored.

        Beam's BoundedSource contract is
        ``split(desired_bundle_size, start_position=None, stop_position=None)``.
        Without the first parameter, a runner's bundle size was bound to
        ``start_position``, yielding a bundle whose start lies past its stop -
        possibly why the source misbehaves on Dataflow but not on DirectRunner.
        """
        logging.info('Enter split')
        if start_position is None:
            start_position = 0
        if stop_position is None:
            stop_position = 1
        yield iobase.SourceBundle(
            weight=1,
            source=self,
            start_position=start_position,
            stop_position=stop_position)
def run(argv=None):
    """
    Main entry point: read TABLE through SpannerSource and write it to GCS as text.

    :param argv: command-line flags for PipelineOptions; None keeps the previous
        behaviour of parsing sys.argv.
    """
    pipeline_options = PipelineOptions(argv)
    google_cloud_options = pipeline_options.view_as(GoogleCloudOptions)
    google_cloud_options.project = PROJECT_ID
    google_cloud_options.job_name = JOB_NAME
    google_cloud_options.staging_location = '%s/staging' % BUCKET_URL
    google_cloud_options.temp_location = '%s/tmp' % BUCKET_URL
    # Use 'DirectRunner' instead to run locally.
    pipeline_options.view_as(StandardOptions).runner = 'DataflowRunner'
    p = beam.Pipeline(options=pipeline_options)
    output = p | 'Get Rows from Spanner' >> beam.io.Read(SpannerSource())
    iso_datetime = dt.datetime.now().replace(microsecond=0).isoformat()
    # Without this sink the job completes but does not do anything.
    output | 'Store in GCS' >> WriteToText(file_path_prefix=OUTPUT + iso_datetime + '-' + TABLE, file_name_suffix='')
    result = p.run()
    result.wait_until_finish()


if __name__ == '__main__':
    logging.getLogger().setLevel(logging.INFO)
    run()
However, this script runs correctly only on the DirectRunner: when I let it run on the DataflowRunner, it runs for a while without any output, before exiting with an error:
"Executing failure step failure14 [...] Workflow failed. Causes: [...] The worker lost contact with the service."
Sometimes, it just goes on forever, without creating an output.
Moreover, if I comment the line 'output = ...', the job completes, but without actually reading the data.
It also appears that the dataflowRunner calls the function 'estimate_size' of the source, but not the functions 'read' or 'get_range_tracker'.
Does anyone have any idea what may cause this?
I know there is a (more complete) java SDK with an experimental spanner source/sink available, but if possible I'd rather stick with python.
Thanks
Google has since added support for backing up Spanner with Dataflow: you can choose the corresponding template when creating a Dataflow job.
For more: https://cloud.google.com/blog/products/gcp/cloud-spanner-adds-import-export-functionality-to-ease-data-movement
I have reworked my code following the suggestion to simply use a ParDo instead of the BoundedSource class. As a reference, here is my solution; I am sure there are many ways to improve on it, and I would be happy to hear opinions.
In particular, I am surprised that I have to create a dummy PCollection when starting the pipeline (if I don't, I get the error
AttributeError: 'PBegin' object has no attribute 'windowing'
which I could not work around). The dummy PCollection feels a bit like a hack.
from __future__ import absolute_import
import datetime as dt
import logging
import apache_beam as beam
from apache_beam.io import WriteToText
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions, SetupOptions
from apache_beam.options.pipeline_options import GoogleCloudOptions
from google.cloud.spanner.client import Client
from google.cloud.spanner.keyset import KeySet
# GCS bucket for the job; OUTPUT is the prefix of the backup text files.
BUCKET_URL = 'gs://my_bucket'
OUTPUT = '%s/some_folder/' % BUCKET_URL
# Spanner coordinates; the project id constant here is PROJECT_ID.
PROJECT_ID = 'my_project'
INSTANCE_ID = 'my_instance'
DATABASE_ID = 'my_database'
JOB_NAME = 'my_jobname'  # Dataflow job name
class ReadTables(beam.DoFn):
    """Emit the name of every user table in the configured Spanner database.

    The input element is ignored; it only triggers the listing.
    """

    def __init__(self, project, instance, database):
        super(ReadTables, self).__init__()
        self._project = project
        self._instance = instance
        self._database = database

    def process(self, element):
        database = Client(self._project).instance(self._instance).database(self._database)
        # Empty table_catalog/table_schema selects the user tables and excludes
        # INFORMATION_SCHEMA and SPANNER_SYS, replacing a hard-coded skip list
        # that missed other system tables.
        table_names_row = database.execute_sql(
            "SELECT t.table_name FROM information_schema.tables AS t "
            "WHERE t.table_catalog = '' AND t.table_schema = ''")
        for row in table_names_row:
            yield row[0]
class ReadSpannerTable(beam.DoFn):
    """For each table name, emit ``(table_name, {column: value})`` for every row."""

    def __init__(self, project, instance, database):
        super(ReadSpannerTable, self).__init__()
        self._project = project
        self._instance = instance
        self._database = database

    def process(self, element):
        # One client/database handle per element instead of one per query.
        database = Client(self._project).instance(self._instance).database(self._database)
        # First read the columns of the table. Restricting to the default schema
        # avoids matching a system table of the same name. `element` is a table
        # name from information_schema, so interpolating it is not user input.
        table_fields = database.execute_sql(
            "SELECT t.column_name FROM information_schema.columns AS t "
            "WHERE t.table_catalog = '' AND t.table_schema = '' AND t.table_name = '%s' "
            "ORDER BY t.ordinal_position" % element)
        columns = [x[0] for x in table_fields]
        # Next, read the actual data in the table.
        keyset = KeySet(all_=True)
        results_streamed_set = database.read(table=element, columns=columns, keyset=keyset)
        for row in results_streamed_set:
            # zip() works on Python 2 and 3 (xrange is Python 2 only).
            yield (element, dict(zip(columns, row)))  # output pairs of (table_name, data)
def run(argv=None):
    """
    Main entry point: back up every user table of the Spanner database to GCS.

    Rows are grouped per table and each ``(table_name, [row_dict, ...])`` pair
    is written as one text record.

    :param argv: command-line flags for PipelineOptions; None keeps the previous
        behaviour of parsing sys.argv.
    """
    pipeline_options = PipelineOptions(argv)
    pipeline_options.view_as(SetupOptions).save_main_session = True
    pipeline_options.view_as(SetupOptions).requirements_file = "requirements.txt"
    google_cloud_options = pipeline_options.view_as(GoogleCloudOptions)
    # PROJECT was never defined; the project constant is PROJECT_ID.
    google_cloud_options.project = PROJECT_ID
    google_cloud_options.job_name = JOB_NAME
    google_cloud_options.staging_location = '%s/staging' % BUCKET_URL
    google_cloud_options.temp_location = '%s/tmp' % BUCKET_URL
    pipeline_options.view_as(StandardOptions).runner = 'DataflowRunner'
    p = beam.Pipeline(options=pipeline_options)
    # A ParDo cannot be applied to the pipeline root (PBegin has no windowing),
    # so a single dummy element starts the pipeline and runs ReadTables once.
    init = p | 'Begin pipeline' >> beam.Create(["test"])
    tables = init | 'Get tables from Spanner' >> beam.ParDo(ReadTables(PROJECT_ID, INSTANCE_ID, DATABASE_ID))
    # NOTE(review): GroupByKey gathers an entire table into one element, which
    # may not scale to large tables.
    rows = (tables
            | 'Get rows from Spanner table' >> beam.ParDo(ReadSpannerTable(PROJECT_ID, INSTANCE_ID, DATABASE_ID))
            | 'Group by table' >> beam.GroupByKey()
            # Force the grouped values to a list (DataflowRunner produces
            # _UnwindowedValues). Index the pair: tuple parameters in a lambda
            # are a SyntaxError on Python 3.
            | 'Formatting' >> beam.Map(lambda kv: (kv[0], list(kv[1]))))
    iso_datetime = dt.datetime.now().replace(microsecond=0).isoformat()
    rows | 'Store in GCS' >> WriteToText(file_path_prefix=OUTPUT + iso_datetime, file_name_suffix='')
    result = p.run()
    result.wait_until_finish()


if __name__ == '__main__':
    logging.getLogger().setLevel(logging.INFO)
    run()

Categories

Resources