dask.distributed: handle serialization of exotic objects?

Context
I am trying to write a data pipeline using Dask distributed and some legacy code from a previous project. get_data simply takes url: str and session: ClientSession as arguments and returns a pandas DataFrame.
from dask.distributed import Client
from aiohttp import ClientSession

client = Client()
session: ClientSession = connector.session_factory()

futures = client.map(
    get_data,  # function to get data (takes url and http session)
    urls,
    [session for _ in range(len(urls))],  # PROBLEM IS HERE
    retries=5,
)
r = client.map(loader.job, futures)
_ = client.gather(r)
Problem
I get the following error:
  File "/home/zar3bski/.cache/pypoetry/virtualenvs/poc-dask-iG-N0GH5-py3.10/lib/python3.10/site-packages/distributed/worker.py", line 2952, in warn_dumps
    b = dumps(obj)
  File "/home/zar3bski/.cache/pypoetry/virtualenvs/poc-dask-iG-N0GH5-py3.10/lib/python3.10/site-packages/distributed/protocol/pickle.py", line 58, in dumps
    result = cloudpickle.dumps(x, **dump_kwargs)
  File "/home/zar3bski/.cache/pypoetry/virtualenvs/poc-dask-iG-N0GH5-py3.10/lib/python3.10/site-packages/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/home/zar3bski/.cache/pypoetry/virtualenvs/poc-dask-iG-N0GH5-py3.10/lib/python3.10/site-packages/cloudpickle/cloudpickle_fast.py", line 632, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle 'TaskStepMethWrapper' object
Unclosed client session
client_session: <aiohttp.client.ClientSession object at 0x7f3042b2fa00>
My temptation was then to register a serializer and a deserializer for this exotic object, following this doc:
from distributed.protocol import dask_serialize, dask_deserialize

@dask_serialize.register(TaskStepMethWrapper)
def serialize(ctx: TaskStepMethWrapper) -> Tuple[Dict, List[bytes]]:
    header = {}  # ?
    frames = []  # ?
    return header, frames

@dask_deserialize.register(TaskStepMethWrapper)
def deserialize(header: Dict, frames: List[bytes]) -> TaskStepMethWrapper:
    return TaskStepMethWrapper(frames)  # ?
The problem is that I don't know where to import TaskStepMethWrapper from. I know that the TaskStepMethWrapper class is asyncio-related:
grep -rnw './' -e '.*TaskStepMethWrapper.*'
grep: ./lib-dynload/_asyncio.cpython-310-x86_64-linux-gnu.so: binary file matches
But I couldn't find its definition anywhere in site-packages/aiohttp. I also tried to use a Client(asynchronous=True), which only resulted in a TypeError: cannot pickle '_contextvars.Context' object.
How do you handle exotic object serialization in Dask? Should I extend the Dask serializer or use an additional serialization family?
client = Client('tcp://scheduler-address:8786',
                serializers=['dask', 'pickle'],     # BUT WHICH ONE
                deserializers=['dask', 'msgpack'])  # BUT WHICH ONE

There is a far easier way to get around this: create your sessions within the mapped function. You would have been recreating the sessions in each worker anyway; they cannot survive a transfer.
from dask.distributed import Client
from aiohttp import ClientSession

client = Client()

def func(u):
    session: ClientSession = connector.session_factory()
    return get_data(u, session)

futures = client.map(
    func,
    urls,
    retries=5,
)
(I don't know what loader.job is, so I have omitted that).
Note that TaskStepMethWrapper (and anything to do with aiohttp) sounds like it should be called only in async code. Maybe func needs to be async and you need appropriate awaits.
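If get_data is itself a coroutine function, a minimal sketch of an async variant, assuming get_data is awaitable and that each mapped call may start its own event loop on the worker, could look like this:
import asyncio
from aiohttp import ClientSession

def func(u):
    # Build both the session and the event loop inside the task,
    # so nothing aiohttp-related ever has to be pickled.
    async def _inner():
        async with ClientSession() as session:
            return await get_data(u, session)  # assumes get_data is async
    return asyncio.run(_inner())

futures = client.map(func, urls, retries=5)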

Related

Gateway Time-out with StreamingResponse and custom Middleware fastapi [duplicate]


How to download a large file using FastAPI?

I am trying to download a large file (.tar.gz) from a FastAPI backend. On the server side, I simply validate the filepath, and I then use Starlette's FileResponse to return the whole file, just like what I've seen in many related questions on StackOverflow.
Server side:
return FileResponse(path=file_name, media_type='application/octet-stream', filename=file_name)
After that, I get the following error:
  File "/usr/local/lib/python3.10/dist-packages/fastapi/routing.py", line 149, in serialize_response
    return jsonable_encoder(response_content)
  File "/usr/local/lib/python3.10/dist-packages/fastapi/encoders.py", line 130, in jsonable_encoder
    return ENCODERS_BY_TYPE[type(obj)](obj)
  File "pydantic/json.py", line 52, in pydantic.json.lambda
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x8b in position 1: invalid start byte
I also tried using StreamingResponse, but got the same error. Any other ways to do it?
The StreamingResponse in my code:
@x.post("/download")
async def download(file_name=Body(), token: str | None = Header(default=None)):
    file_name = file_name["file_name"]
    # should be something like xx.tar
    def iterfile():
        with open(file_name, "rb") as f:
            yield from f
    return StreamingResponse(iterfile(), media_type='application/octet-stream')
OK, here is an update on this problem.
I found that the error did not occur in this API, but in the API that forwards requests to it:
#("/")
def f():
req = requests.post(url ="/download")
return req.content
And here, when I returned a StreamingResponse with the .tar file, it led to (maybe) encoding problems.
When using requests, remember to set the same media type; here it is media_type='application/octet-stream'. And it works!
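As a side note, returning req.content buffers the entire forwarded file in memory. A minimal sketch of a streaming alternative (the UPSTREAM_URL constant below is a hypothetical absolute address for the download endpoint):
import requests
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
UPSTREAM_URL = 'http://localhost:8000/download'  # hypothetical upstream address

@app.post('/')
def forward():
    # stream=True keeps requests from reading the whole body at once
    upstream = requests.post(UPSTREAM_URL, stream=True)
    return StreamingResponse(
        upstream.iter_content(chunk_size=1024 * 1024),
        media_type=upstream.headers.get('content-type', 'application/octet-stream'),
    )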
If you find yield from f being rather slow when using StreamingResponse with file-like objects, you could instead create a generator where you read the file in chunks using a specified chunk size; hence, speeding up the process. Examples can be found below.
Note that StreamingResponse can take either an async generator or a normal generator/iterator to stream the response body. In case you used the standard open() method, which doesn't support async/await, you would have to declare the generator function with normal def. Regardless, FastAPI/Starlette will still work asynchronously, as it will check whether the generator you passed is asynchronous (as shown in the source code), and if it is not, it will then run the generator in a separate thread, using iterate_in_threadpool, which is then awaited.
You can set the Content-Disposition header in the response (as described in this answer, as well as here and here) to indicate if the content is expected to be displayed inline in the browser (if you are streaming, for example, a .mp4 video, .mp3 audio file, etc), or as an attachment that is downloaded and saved locally (using the specified filename).
As for the media_type (also known as MIME type), there are two primary MIME types (see Common MIME types):
text/plain is the default value for textual files. A textual file should be human-readable and must not contain binary data.
application/octet-stream is the default value for all other cases. An unknown file type should use this type.
For a file with .tar extension, as shown in your question, you can also use a different subtype from octet-stream, that is, x-tar. Otherwise, if the file is of unknown type, stick to application/octet-stream. See the linked documentation above for a list of common MIME types.
Option 1 - Using normal generator
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

CHUNK_SIZE = 1024 * 1024  # = 1MB - adjust the chunk size as desired
some_file_path = 'large_file.tar'
app = FastAPI()

@app.get('/')
def main():
    def iterfile():
        with open(some_file_path, 'rb') as f:
            while chunk := f.read(CHUNK_SIZE):
                yield chunk
    headers = {'Content-Disposition': 'attachment; filename="large_file.tar"'}
    return StreamingResponse(iterfile(), headers=headers, media_type='application/x-tar')
Option 2 - Using async generator with aiofiles
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
import aiofiles

CHUNK_SIZE = 1024 * 1024  # = 1MB - adjust the chunk size as desired
some_file_path = 'large_file.tar'
app = FastAPI()

@app.get('/')
async def main():
    async def iterfile():
        async with aiofiles.open(some_file_path, 'rb') as f:
            while chunk := await f.read(CHUNK_SIZE):
                yield chunk
    headers = {'Content-Disposition': 'attachment; filename="large_file.tar"'}
    return StreamingResponse(iterfile(), headers=headers, media_type='application/x-tar')
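On the client side, a short sketch of consuming such a streamed download in chunks, using httpx (the URL below assumes the server above is running locally):
import httpx

with httpx.stream('GET', 'http://localhost:8000/') as r:
    with open('large_file.tar', 'wb') as f:
        # Write the body to disk chunk by chunk, never holding it all in RAM
        for chunk in r.iter_bytes():
            f.write(chunk)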

ConnectionAbortedError: [WinError 10053] when trying to connect to itself with a web app

I have just run into a funny situation when testing my FastAPI Python application, and thought it might be useful for some of the people who reuse sessions in their apps and want to test requests using the same app, but get stuck on weird errors like the one in the title.
Also, I would like to know what is happening here.
Context
I have an async FastAPI application that schedules multiple requests based on a configuration (the details are unimportant). After the list of request definitions is prepared, a session is created and the requests are sent, possibly with delays so I can spread them out in time.
To test whether the requests are getting through, I have created routes in my own app so I can send the testing requests back to my own application. The application basically talks to itself.
It was listening on 127.0.0.1:8000 at the time of testing.
I have the following functions defined for building async tasks:
def optional_session(func):
    async def wrapper(*args, **kwargs):
        if 'session' not in kwargs or kwargs['session'] is None:
            async with ClientSession() as session:
                kwargs['session'] = session
                return await func(*args, **kwargs)
        else:
            return await func(*args, **kwargs)
    return wrapper
@optional_session
async def post_json_with_time_from_url(url: str, data: dict, session: ClientSession = None) -> Tuple[Union[dict, None], float]:
    """
    A method that performs a request to a specified URL and reads the response as JSON data.
    If the request is successful the data is returned. If an error occurs it is logged and the returned data is None.
    :param data: data to send in the request
    :param url: the URL to retrieve the data from
    :param session: optional ClientSession
    :return: a valid response or None
    """
    result = None, time.time()
    try:
        async with session.post(url, data=data) as response:  # type: ClientResponse
            # check if the response is valid
            if response.status == 200:
                try:
                    # we have to read the response before leaving the response context manager
                    result = await response.json(), time.time()
                except Exception as e:
                    logger.error("...")
            else:
                logger.error("...")
    except InvalidURL as e:
        logger.error(f"...")
    except Exception as e:
        logger.error("...")
    return result
def delay(func, seconds: int):
    """
    This decorator adds a time delay to an async function.
    """
    if seconds is None:
        seconds = 0
    async def wrapper(*args, **kwargs):
        await asyncio.sleep(seconds)
        return await func(*args, **kwargs)
    return wrapper
def parse_get_post_request(config: ConfigContext, session: aiohttp.ClientSession = None) -> asyncio.Task:
    """
    Parses the get/post request from the configuration dictionary and creates an async task for it.
    """
    request_type = config.extract_key('request_type', True).lower()
    delay_ = config.extract_key('delay')
    url_base_ = config.extract_key('request_url_base', True)
    url_suffix_ = config.extract_key('request_url_suffix', True)
    url_ = urljoin(base=url_base_, url=url_suffix_)
    if request_type == 'get':
        return asyncio.ensure_future(
            delay(get_json_with_time_from_url, delay_)(url=url_, session=session)
        )
    elif request_type == 'post':
        return asyncio.ensure_future(
            delay(post_json_with_time_from_url, delay_)(url=url_, session=session, data=config.extract_key('request_data'))
        )
    else:
        raise ValueError(f"Unsupported request type: {request_type}")
I am creating an aiohttp session like this:
async with aiohttp.ClientSession() as session:
    ...
and then reusing it throughout the context code block, something like this:
single_request_tasks = []
...
for config in configs:
    single_request_tasks.append(parse_get_post_request(config=plan_config, session=session))
...
responses = await asyncio.gather(*single_request_tasks)
...
Problem
Somehow, when I send the requests all together and one of them arrives back at the app at the same time as another, an exception is thrown:
ConnectionAbortedError: [WinError 10053] An established connection was aborted by the software in your host machine
It turns out that, for some reason, the session I share across all the requests is terminated when multiple requests arrive at the same time using the same ClientSession instance.
I am not really sure why this happens exactly, apart from suspecting some port-clash shenanigans,
but it is resolved when I use a separate session for each request, or when I spread the requests out in time with an interval of, for example, one second.
Workaround
I have used separate sessions for each request when looping back to localhost.
I also avoided the issue when I spread the requests out in time, so that each one has time to complete before the next is sent; but timing is not that reliable a mechanism (given the OS task scheduler, concurrency in asyncio, network latency, etc.).
This problem does not occur when sharing a session with a different host (for example, when scraping images from imgur.com), so I believe it is related to the fact that I am looping back to localhost.
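For reference, the workaround boils down to not passing the shared session at all, so the optional_session decorator above opens a fresh ClientSession per request; a minimal sketch:
# With session=None, the @optional_session wrapper creates a dedicated
# ClientSession for each request, avoiding the shared-session abort.
single_request_tasks = [
    parse_get_post_request(config=config, session=None)
    for config in configs
]
responses = await asyncio.gather(*single_request_tasks)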
Question
Why does this happen, exactly? Why is the session closed by the software in the situation I described?
Is there anything I am doing wrong with the session? How does Starlette handle loopback connections? Is this case-dependent, and do I need to do more detective work, or is this a generally recognized, platform-independent behaviour?

How to log raw HTTP request/response in Python FastAPI?

We are writing a web service using Python FastAPI that is going to be hosted in Kubernetes. For auditing purposes, we need to save the raw JSON body of the request/response for specific routes. The body size of both request and response JSON is about 1MB, and preferably, this should not impact the response time.
How can we do that?
Option 1 - Using Middleware
You could use a Middleware. A middleware takes each request that comes to your application, and hence, allows you to handle the request before it is processed by any specific endpoint, as well as the response, before it is returned to the client. To create a middleware, you use the decorator @app.middleware("http") on top of a function, as shown below. As you need to consume the request body from the stream inside the middleware—using either request.body() or request.stream(), as shown in this answer (behind the scenes, the former method actually calls the latter, see here)—then it won't be available when you later pass the request to the corresponding endpoint. Thus, you can follow the approach described in this post to make the request body available down the line (i.e., using the set_body function below). As for the response body, you can use the same approach as described in this answer to consume the body and then return the response to the client. Either option described in the aforementioned linked answer would work; the below, however, uses Option 2, which stores the body in a bytes object and returns a custom Response directly (along with the status_code, headers and media_type of the original response).
To log the data, you could use a BackgroundTask, as described in this answer and this answer. A BackgroundTask will run only once the response has been sent (see Starlette documentation as well); thus, the client won't have to be waiting for the logging to complete before receiving the response (and hence, the response time won't be noticeably impacted).
Note
If you had a streaming request or response with a body that wouldn't fit into your server's RAM (for example, imagine a body of 100GB on a machine running 8GB RAM), it would become problematic, as you are storing the data to RAM, which wouldn't have enough space available to accommodate the accumulated data. Also, in case of a large response (e.g., a large FileResponse or StreamingResponse), you may be faced with Timeout errors on client side (or on reverse proxy side, if you are using one), as you would not be able to respond back to the client until you have read the entire response body (as you are looping over response.body_iterator). You mentioned that "the body size of both request and response JSON is about 1MB"; hence, that should normally be fine (however, it is always a good practice to consider beforehand matters such as how many requests your API is expected to be serving concurrently, what other applications might be using the RAM, etc., in order to determine whether this is an issue or not). If you needed to, you could limit the number of requests to your API endpoints using, for example, SlowAPI (as shown in this answer).
Limiting the usage of the middleware to specific routes only
You could limit the usage of the middleware to specific endpoints by:
checking the request.url.path inside the middleware against a pre-defined list of routes for which you would like to log the request and response, as described in this answer (see the "Update" section, and the sketch below),
or using a sub application, as demonstrated in this answer,
or using a custom APIRoute class, as demonstrated in Option 2 below.
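Here is a minimal sketch of that path-checking approach (the routes_to_log set is a hypothetical example):
routes_to_log = {'/', '/items'}  # hypothetical set of routes to audit

@app.middleware('http')
async def selective_middleware(request: Request, call_next):
    if request.url.path not in routes_to_log:
        # Bypass logging entirely for routes that are not audited
        return await call_next(request)
    # ...otherwise, log the request/response as shown in the Working Example below
    return await call_next(request)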
Working Example
from fastapi import FastAPI, APIRouter, Response, Request
from starlette.background import BackgroundTask
from fastapi.routing import APIRoute
from starlette.types import Message
from typing import Dict, Any
import logging

app = FastAPI()
logging.basicConfig(filename='info.log', level=logging.DEBUG)

def log_info(req_body, res_body):
    logging.info(req_body)
    logging.info(res_body)

async def set_body(request: Request, body: bytes):
    async def receive() -> Message:
        return {'type': 'http.request', 'body': body}
    request._receive = receive

@app.middleware('http')
async def some_middleware(request: Request, call_next):
    req_body = await request.body()
    await set_body(request, req_body)
    response = await call_next(request)
    res_body = b''
    async for chunk in response.body_iterator:
        res_body += chunk
    task = BackgroundTask(log_info, req_body, res_body)
    return Response(content=res_body, status_code=response.status_code,
                    headers=dict(response.headers), media_type=response.media_type, background=task)

@app.post('/')
def main(payload: Dict[Any, Any]):
    return payload
In case you would like to perform some validation on the request body—for example, ensuring that the request body size is not exceeding a certain value—instead of using request.body(), you can process the body one chunk at a time using the .stream() method, as shown below (similar to this answer).
@app.middleware('http')
async def some_middleware(request: Request, call_next):
    req_body = b''
    async for chunk in request.stream():
        req_body += chunk
    ...
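As a sketch of such a validation (MAX_BODY_SIZE below is a hypothetical limit), you could abort as soon as the accumulated size exceeds the threshold, instead of reading the rest of the stream:
MAX_BODY_SIZE = 2 * 1024 * 1024  # hypothetical 2MB limit

@app.middleware('http')
async def some_middleware(request: Request, call_next):
    req_body = b''
    async for chunk in request.stream():
        req_body += chunk
        if len(req_body) > MAX_BODY_SIZE:
            # Reject oversized bodies before buffering them entirely
            return Response(content='Request body too large', status_code=413)
    await set_body(request, req_body)  # reuse set_body from the Working Example above
    return await call_next(request)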
Option 2 - Using custom APIRoute class
You can alternatively use a custom APIRoute class—similar to here and here—which, among other things, would allow you to manipulate the request body before it is processed by your application, as well as the response body before it is returned to the client. This option also allows you to limit the usage of this class to the routes you wish, as only the endpoints under the APIRouter (i.e., router in the example below) will use the custom APIRoute class.
It should be noted that the same comments mentioned in Option 1 above, under the "Note" section, apply to this option as well. For example, if your API returns a StreamingResponse—such as in the /video route of the example below, which is streaming a video file from an online source (public videos to test this can be found here, and you can even use a longer video than the one used below to see the effect more clearly)—you may come across issues on server side, if your server's RAM can't handle it, as well as delays on client side (and reverse proxy server, if using one), due to the whole (streaming) response being read and stored in RAM, before it is returned to the client (as explained earlier). In such cases, you could exclude such endpoints that return a StreamingResponse from the custom APIRoute class and limit its usage only to the desired routes—especially if it is a large video file, or even live video, that wouldn't likely make much sense to have stored in the logs—simply by not using the @<name_of_router> decorator (i.e., @router in the example below) for such endpoints, but rather using the @<name_of_app> decorator (i.e., @app in the example below), or some other APIRouter or sub application.
Working Example
from fastapi import FastAPI, APIRouter, Response, Request
from starlette.background import BackgroundTask
from starlette.responses import StreamingResponse
from fastapi.routing import APIRoute
from starlette.types import Message
from typing import Callable, Dict, Any
import logging
import httpx

def log_info(req_body, res_body):
    logging.info(req_body)
    logging.info(res_body)

class LoggingRoute(APIRoute):
    def get_route_handler(self) -> Callable:
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            req_body = await request.body()
            response = await original_route_handler(request)
            if isinstance(response, StreamingResponse):
                res_body = b''
                async for item in response.body_iterator:
                    res_body += item
                task = BackgroundTask(log_info, req_body, res_body)
                return Response(content=res_body, status_code=response.status_code,
                                headers=dict(response.headers), media_type=response.media_type, background=task)
            else:
                res_body = response.body
                response.background = BackgroundTask(log_info, req_body, res_body)
                return response

        return custom_route_handler

app = FastAPI()
router = APIRouter(route_class=LoggingRoute)
logging.basicConfig(filename='info.log', level=logging.DEBUG)

@router.post('/')
def main(payload: Dict[Any, Any]):
    return payload

@router.get('/video')
def get_video():
    url = 'https://storage.googleapis.com/gtv-videos-bucket/sample/ForBiggerBlazes.mp4'
    def gen():
        with httpx.stream('GET', url) as r:
            for chunk in r.iter_raw():
                yield chunk
    return StreamingResponse(gen(), media_type='video/mp4')

app.include_router(router)
You may try to customize APIRoute, as shown in the FastAPI official documentation:
import time
from typing import Callable
from fastapi import APIRouter, FastAPI, Request, Response
from fastapi.routing import APIRoute

class TimedRoute(APIRoute):
    def get_route_handler(self) -> Callable:
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            before = time.time()
            response: Response = await original_route_handler(request)
            duration = time.time() - before
            response.headers["X-Response-Time"] = str(duration)
            print(f"route duration: {duration}")
            print(f"route response: {response}")
            print(f"route response headers: {response.headers}")
            return response

        return custom_route_handler

app = FastAPI()
router = APIRouter(route_class=TimedRoute)

@app.get("/")
async def not_timed():
    return {"message": "Not timed"}

@router.get("/timed")
async def timed():
    return {"message": "It's the time of my life"}

app.include_router(router)
As the other answers did not work for me, and I searched quite extensively on Stack Overflow to fix this problem, I will show my solution below.
The main issue is that many of the approaches/solutions offered online for working with the request or response body simply do not work, as the request/response body is consumed when reading it from the stream.
To solve this, I adapted an approach that basically reconstructs the request and response after reading them. This is heavily based on the comment by user 'kovalevvlad' on https://github.com/encode/starlette/issues/495.
Custom middleware is created and later added to the app to log all requests and responses. Note that you need some kind of logger to store your logs.
from json import JSONDecodeError
import json
import logging
from typing import Callable, Awaitable, Tuple, Dict, List
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import Scope, Message

# Set up your custom logger here (a plain module logger is used as a placeholder)
logger = logging.getLogger(__name__)

class RequestWithBody(Request):
    """Creation of a new request with body"""
    def __init__(self, scope: Scope, body: bytes) -> None:
        super().__init__(scope, self._receive)
        self._body = body
        self._body_returned = False

    async def _receive(self) -> Message:
        if self._body_returned:
            return {"type": "http.disconnect"}
        else:
            self._body_returned = True
            return {"type": "http.request", "body": self._body, "more_body": False}

class CustomLoggingMiddleware(BaseHTTPMiddleware):
    """
    Use of custom middleware, since reading the request body and the response consumes the bytestream.
    Hence, this approach basically generates a new request/response when we read the attributes for logging.
    """
    async def dispatch(  # type: ignore
        self, request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]]
    ) -> Response:
        # Store request body in a variable and generate a new request, as it is consumed.
        request_body_bytes = await request.body()
        request_with_body = RequestWithBody(request.scope, request_body_bytes)
        # Store response body in a variable and generate a new response, as it is consumed.
        response = await call_next(request_with_body)
        response_content_bytes, response_headers, response_status = await self._get_response_params(response)
        # If there is no JSON request body, handle the exception; otherwise, convert bytes to JSON.
        try:
            req_body = json.loads(request_body_bytes)
        except JSONDecodeError:
            req_body = ""
        # Logging of relevant variables.
        logger.info(
            f"{request.method} request to {request.url} metadata\n"
            f"\tStatus_code: {response.status_code}\n"
            f"\tRequest_Body: {req_body}\n"
        )
        # Finally, return the newly instantiated response values
        return Response(response_content_bytes, response_status, response_headers)

    async def _get_response_params(self, response: StreamingResponse) -> Tuple[bytes, Dict[str, str], int]:
        """Get the response parameters of a response and create a new response."""
        response_byte_chunks: List[bytes] = []
        response_status: List[int] = []
        response_headers: List[Dict[str, str]] = []

        async def send(message: Message) -> None:
            if message["type"] == "http.response.start":
                response_status.append(message["status"])
                response_headers.append({k.decode("utf8"): v.decode("utf8") for k, v in message["headers"]})
            else:
                response_byte_chunks.append(message["body"])

        await response.stream_response(send)
        content = b"".join(response_byte_chunks)
        return content, response_headers[0], response_status[0]
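For completeness, a short usage sketch: the middleware is registered on the app with add_middleware (the FastAPI app below is assumed):
from fastapi import FastAPI

app = FastAPI()
# Route every request/response through the logging middleware
app.add_middleware(CustomLoggingMiddleware)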

Reading multiple "bulked" jsons from s3 asynchronously. Is there a better way?

The goal is to load a large number of "bulked" JSONs from S3. I found aiobotocore and felt urged to try it, in the hope of gaining more efficiency and at the same time familiarising myself with asyncio. I gave it a shot, and it works, but I know basically nothing about asynchronous programming. Therefore, I was hoping for some improvements/comments. Maybe there are some kind souls out there who can spot some obvious mistakes.
The problem is that boto3 only supports one HTTP request at a time. By utilising a thread pool I managed to get significant improvements, but I'm hoping for a more efficient way.
Here is the code:
Imports:
import os
import asyncio
import aiobotocore
from itertools import chain
import json
from json.decoder import WHITESPACE
A helper generator I found somewhere that yields decoded JSONs from a string containing multiple JSONs:
def iterload(string_or_fp, cls=json.JSONDecoder, **kwargs):
    '''helper for parsing individual jsons from string of jsons (stolen from somewhere)'''
    string = str(string_or_fp)
    decoder = cls(**kwargs)
    idx = WHITESPACE.match(string, 0).end()
    while idx < len(string):
        obj, end = decoder.raw_decode(string, idx)
        yield obj
        idx = WHITESPACE.match(string, end).end()
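For illustration, a quick usage of this helper on a hypothetical string of concatenated JSON objects:
bulked = '{"id": 1} {"id": 2}\n{"id": 3}'
print(list(iterload(bulked)))
# [{'id': 1}, {'id': 2}, {'id': 3}]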
This function gets keys from an s3 bucket with a given prefix:
# Async stuff starts here
async def get_keys(loop, bucket, prefix):
    '''Get keys in bucket based on prefix'''
    session = aiobotocore.get_session(loop=loop)
    async with session.create_client('s3', region_name='us-west-2',
                                     aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
                                     aws_access_key_id=AWS_ACCESS_KEY_ID) as client:
        keys = []
        # list s3 objects using paginator
        paginator = client.get_paginator('list_objects')
        async for result in paginator.paginate(Bucket=bucket, Prefix=prefix):
            for c in result.get('Contents', []):
                keys.append(c['Key'])
        return keys
This function gets the content for a provided key. On top of that, it flattens the list of decoded content:
async def get_object(loop, bucket, key):
    '''Get json content from s3 object'''
    session = aiobotocore.get_session(loop=loop)
    async with session.create_client('s3', region_name='us-west-2',
                                     aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
                                     aws_access_key_id=AWS_ACCESS_KEY_ID) as client:
        # get object from s3
        response = await client.get_object(Bucket=bucket, Key=key)
        async with response['Body'] as stream:
            content = await stream.read()
        return list(iterload(content.decode()))
Here is the main function, which gathers the contents for all the found keys and flattens the resulting lists:
async def go(loop, bucket, prefix):
    '''Returns list of dicts of object contents'''
    session = aiobotocore.get_session(loop=loop)
    async with session.create_client('s3', region_name='us-west-2',
                                     aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
                                     aws_access_key_id=AWS_ACCESS_KEY_ID) as client:
        keys = await get_keys(loop, bucket, prefix)
        contents = await asyncio.gather(*[get_object(loop, bucket, k) for k in keys])
        return list(chain.from_iterable(contents))
Finally, I run this, and the resulting list of dicts ends up nicely in result:
loop = asyncio.get_event_loop()
result = loop.run_until_complete(go(loop, 'some-bucket', 'some-prefix'))
One thing that I think might be a bit weird is that I create a client in each async function. Probably that can be lifted out. Not sure how aiobotocore works with multiple clients, though.
Furthermore, I think you would not need to wait until all keys are loaded before loading the objects for those keys, which I think is the case in this implementation. I'm assuming that as soon as a key is found you could call get_object. So maybe it should be an async generator. But I'm not completely in the clear here.
Thank you in advance! Hope this helps someone in a similar situation.
First, check out aioboto3.
Second, each client in aiobotocore is associated with an aiohttp session. Each session can have up to max_pool_connections. This is why the basic aiobotocore example does an async with on create_client: so the pool is closed when you are done using the client.
Here are some tips:
You should use a work pool (created by me, modularized by CaliDog) to avoid polluting your event loop. When using this, think of your workflow as a stream.
This will also avoid you having to use asyncio.gather, which leaves tasks running in the background after the first exception is thrown.
You should tune your work pool size and max_pool_connections together, and use only one client, with the number of tasks you want to (or can, based on the compute required) support in parallel.
You really don't need to pass the loop around, as with modern Python versions there's one loop per thread.
You should use AWS profiles (the profile param to Session init) or environment variables, so you don't need to hardcode key and region information.
Based on the above here is how I would do it:
import asyncio
from itertools import chain
import json
from typing import List
from json.decoder import WHITESPACE
import logging
from functools import partial

# Third Party
import asyncpool
import aiobotocore.session
import aiobotocore.config

_NUM_WORKERS = 50

def iterload(string_or_fp, cls=json.JSONDecoder, **kwargs):
    # helper for parsing individual jsons from string of jsons (stolen from somewhere)
    string = str(string_or_fp)
    decoder = cls(**kwargs)
    idx = WHITESPACE.match(string, 0).end()
    while idx < len(string):
        obj, end = decoder.raw_decode(string, idx)
        yield obj
        idx = WHITESPACE.match(string, end).end()

async def get_object(s3_client, bucket: str, key: str):
    # get the object from s3 and parse its body as bulked json
    response = await s3_client.get_object(Bucket=bucket, Key=key)
    async with response['Body'] as stream:
        content = await stream.read()
    return list(iterload(content.decode()))

async def go(bucket: str, prefix: str) -> List[dict]:
    """
    Returns list of dicts of object contents

    :param bucket: s3 bucket
    :param prefix: s3 bucket prefix
    :return: list of dicts of object contents
    """
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
    session = aiobotocore.session.AioSession()
    config = aiobotocore.config.AioConfig(max_pool_connections=_NUM_WORKERS)
    contents = []
    async with session.create_client('s3', config=config) as client:
        worker_co = partial(get_object, client, bucket)
        async with asyncpool.AsyncPool(None, _NUM_WORKERS, 's3_work_queue', logger, worker_co,
                                       return_futures=True, raise_on_join=True, log_every_n=10) as work_pool:
            # list s3 objects using paginator
            paginator = client.get_paginator('list_objects')
            async for result in paginator.paginate(Bucket=bucket, Prefix=prefix):
                for c in result.get('Contents', []):
                    contents.append(await work_pool.push(c['Key']))
    # retrieve results from futures (resolved once the pool has joined)
    contents = [c.result() for c in contents]
    return list(chain.from_iterable(contents))
_loop = asyncio.get_event_loop()
_result = _loop.run_until_complete(go('some-bucket', 'some-prefix'))
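On Python 3.7+, the last two lines can arguably be written more simply with asyncio.run, which creates and closes the event loop for you:
_result = asyncio.run(go('some-bucket', 'some-prefix'))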
