I have a project with many async functions (they do http call to a remote server).
The server can handle up to k simultaneous request (it depends of which request i do)
I want to be able to easily implement "do this N requests with up to K workers". I could not find a way without either duplicating code or loosing type hints.
How to do it ?
I tried to use Semaphore but failed to use it, mainly because I could not find a way to move this logic at the caller scope.
implement logic on caller scope (not working)
import asyncio
import time
async def sleep(s: int):
await asyncio.sleep(s)
print(f'task done in {s}s')
return s
sem = asyncio.Semaphore(2)
# Expect to ends in 8 seconds, ends in 6 (not the correct behavior)
start = time.time()
async with sem:
res = await asyncio.gather(*[
sleep(6),
sleep(2),
sleep(2),
sleep(2),
sleep(2),
sleep(2),
])
end = time.time()
print(f"ended in {end-start}s")
Refactoring async functions (code duplication)
import asyncio
import time
from typing import Optional
async def sleep(s: int):
await asyncio.sleep(s)
async def request_0(arg0: int, sem: Optional[asyncio.Semaphore]=None):
if sem is not None:
async with sem:
await sleep(2) # Do http call 0
else:
await sleep(2) # Do http call 0
async def request_1(arg1: str, sem: Optional[asyncio.Semaphore]=None):
if sem is not None:
async with sem:
await sleep(1) # Do http call 1
else:
await sleep(1) # Do http call 1
async def request_2(arg2: float, arg3: str, sem: Optional[asyncio.Semaphore]=None):
if sem is not None:
async with sem:
await sleep(1) # Do http call 2
else:
await sleep(1) # Do http call 2
start = time.time()
async with asyncio.Semaphore(2) as sem:
res = await asyncio.gather(*[
request_0(arg0=0),
request_1(arg1='0'),
request_2(arg2=0, arg3='0'),
])
end = time.time()
print(f"ended in {end-start}s")
Enhance functions with annotation (losing type hints)
# type: ignore
import asyncio
import time
from typing import Callable, Coroutine, Optional, Any
async def sleep(s: int):
await asyncio.sleep(s)
def semaphoreUseable(func):
async def wrapper(*args, sem=Optional[asyncio.Semaphore], **kwargs):
if sem is not None:
async with sem:
return await func(*args, **kwargs)
else:
return await func(*args, **kwargs)
return wrapper
#semaphoreUseable
async def request_0(arg0: int):
await sleep(2) # http call 0
#semaphoreUseable
async def request_1(arg1: str):
await sleep(1) # http call 1
#semaphoreUseable
async def request_2(arg2: float, arg3: str):
await sleep(1) # http call 2
async def main():
start = time.time()
async with asyncio.Semaphore(2) as sem:
res = await asyncio.gather(*[
request_0(arg0=0, sem=sem),
request_1(arg1='0', sem=sem),
request_2(arg2=0, arg3='0', sem=sem),
])
end = time.time()
print(f"ended in {end-start}s")
if __name__ == "__main__":
asyncio.run(main())
related question: How to type a function with Callable without losing keyword argument? (could not find a way to type the annotation solution)
I'm trying to test an async request but I didn't find how to do. I tried with patch decorator, with AsyncMock... Everytime, I had either aexit error or AsyncMock can't be used in await expression... Where am I wrong ?
class RequestService:
async def requestPostPicture(self, session: aiohttp.ClientSession, photoData: dict):
try:
with aiohttp.MultipartWriter('form-data') as mpwriter:
part = mpwriter.append(photoData['file'][1],{'content-type': photoData['file'][2]})
part.set_content_disposition('form-data', name='file', filename=photoData['file'][0])
async with session.post('https://www.api-url.com', data=mpwriter, headers=self.headers) as resp:
if isinstance(resp, dict):
return resp
apiResponse = await resp.json
return apiResponse
except Exception as error:
return {'error': str(error)}
My test :
class TestRequestService(IsolatedAsyncioTestCase):
#patch('aiohttp.ClientSession.post')
async def testRequestPostPictureDict(self, mockPost):
mockPost.__aenter__.return_value = {"error": "test"}
requestservice = RequestService()
pictureTest = {'file': ('photodatatest.jpg', 'photodatatest', 'image/jpeg')}
connector = aiohttp.TCPConnector(limit=15)
async with aiohttp.ClientSession(connector=connector) as sessionPicture:
returnValue = await requestservice.requestPostPicture(sessionPicture, pictureTest)
self.assertEqual(returnValue, {'error': 'test'})
async def testRequestPostPictureDict(self):
mock = aiohttp.ClientSession
mock.post = MagicMock()
mock.post.return_value.__aenter__.return_value = {'error': 'test'}
The code below is intended to send multiple HTTP requests asynchronously in a while loop, and depending on the response from each request(request "X" always returns "XXX", "Y" always returns "YYY" and so on), do something and sleep for interval seconds specified for each request.
However, it throws an error...
RuntimeError: cannot reuse already awaited coroutine
Could anyone help me how I could fix the code to realise the intended behaviour?
class Client:
def __init__(self):
pass
async def run_forever(self, coro, interval):
while True:
res = await coro
await self._onresponse(res, interval)
async def _onresponse(self, res, interval):
if res == "XXX":
# ... do something with the resonse ...
await asyncio.sleep(interval)
if res == "YYY":
# ... do something with the resonse ...
await asyncio.sleep(interval)
if res == "ZZZ":
# ... do something with the resonse ...
await asyncio.sleep(interval)
async def request(something):
# ... HTTP request using aiohttp library ...
return response
async def main():
c = Client()
await c.run_forever(request("X"), interval=1)
await c.run_forever(request("Y"), interval=2)
await c.run_forever(request("Z"), interval=3)
# ... and more
As the error says, you can't await a coroutine more than once. Instead of passing a coroutine into run_forever and then awaiting it in a loop, passing the coroutine's argument(s) instead and await a new coroutine each iteration of the loop.
class Client:
async def run_forever(self, value, interval):
while True:
res = await rqequest(value)
await self._response(response, interval)
You also need to change how you await run_forever. await is blocking, so when you await something with an infinite loop, you'll never reach the next line. Instead, you want to gather multiple coroutines as once.
async def main():
c = Client()
await asyncio.gather(
c.run_forever("X", interval=1),
c.run_forever("Y", interval=2),
c.run_forever("Z", interval=3),
)
I have such middleware
class RequestContext(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
request_id = request_ctx.set(str(uuid4())) # generate uuid to request
body = await request.body()
if body:
logger.info(...) # log request with body
else:
logger.info(...) # log request without body
response = await call_next(request)
response.headers['X-Request-ID'] = request_ctx.get()
logger.info("%s" % (response.status_code))
request_ctx.reset(request_id)
return response
So the line body = await request.body() freezes all requests that have body and I have 504 from all of them. How can I safely read the request body in this context? I just want to log request parameters.
I would not create a Middleware that inherits from BaseHTTPMiddleware since it has some issues, FastAPI gives you a opportunity to create your own routers, in my experience this approach is way better.
from fastapi import APIRouter, FastAPI, Request, Response, Body
from fastapi.routing import APIRoute
from typing import Callable, List
from uuid import uuid4
class ContextIncludedRoute(APIRoute):
def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()
async def custom_route_handler(request: Request) -> Response:
request_id = str(uuid4())
response: Response = await original_route_handler(request)
if await request.body():
print(await request.body())
response.headers["Request-ID"] = request_id
return response
return custom_route_handler
app = FastAPI()
router = APIRouter(route_class=ContextIncludedRoute)
#router.post("/context")
async def non_default_router(bod: List[str] = Body(...)):
return bod
app.include_router(router)
Works as expected.
b'["string"]'
INFO: 127.0.0.1:49784 - "POST /context HTTP/1.1" 200 OK
In case you still wanted to use BaseHTTP, I recently ran into this problem and came up with a solution:
Middleware Code
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
import json
from .async_iterator_wrapper import async_iterator_wrapper as aiwrap
class some_middleware(BaseHTTPMiddleware):
async def dispatch(self, request:Request, call_next:RequestResponseEndpoint):
# --------------------------
# DO WHATEVER YOU TO DO HERE
#---------------------------
response = await call_next(request)
# Consuming FastAPI response and grabbing body here
resp_body = [section async for section in response.__dict__['body_iterator']]
# Repairing FastAPI response
response.__setattr__('body_iterator', aiwrap(resp_body)
# Formatting response body for logging
try:
resp_body = json.loads(resp_body[0].decode())
except:
resp_body = str(resp_body)
async_iterator_wrapper Code from
TypeError from Python 3 async for loop
class async_iterator_wrapper:
def __init__(self, obj):
self._it = iter(obj)
def __aiter__(self):
return self
async def __anext__(self):
try:
value = next(self._it)
except StopIteration:
raise StopAsyncIteration
return value
I really hope this can help someone! I found this very helpful for logging.
Big thanks to #Eddified for the aiwrap class
You can do this safely with a generic ASGI middleware:
from typing import Iterable, List, Protocol, Generator
import pytest
from starlette.responses import Response
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Scope, Send, Receive, Message
class Logger(Protocol):
def info(self, message: str) -> None:
...
class BodyLoggingMiddleware:
def __init__(
self,
app: ASGIApp,
logger: Logger,
) -> None:
self.app = app
self.logger = logger
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
done = False
chunks: "List[bytes]" = []
async def wrapped_receive() -> Message:
nonlocal done
message = await receive()
if message["type"] == "http.disconnect":
done = True
return message
body = message.get("body", b"")
more_body = message.get("more_body", False)
if not more_body:
done = True
chunks.append(body)
return message
try:
await self.app(scope, wrapped_receive, send)
finally:
while not done:
await wrapped_receive()
self.logger.info(b"".join(chunks).decode()) # or somethin
async def consume_body_app(scope: Scope, receive: Receive, send: Send) -> None:
done = False
while not done:
msg = await receive()
done = "more_body" not in msg
await Response()(scope, receive, send)
async def consume_partial_body_app(scope: Scope, receive: Receive, send: Send) -> None:
await receive()
await Response()(scope, receive, send)
class TestException(Exception):
pass
async def consume_body_and_error_app(scope: Scope, receive: Receive, send: Send) -> None:
done = False
while not done:
msg = await receive()
done = "more_body" not in msg
raise TestException
async def consume_partial_body_and_error_app(scope: Scope, receive: Receive, send: Send) -> None:
await receive()
raise TestException
class TestLogger:
def __init__(self, recorder: List[str]) -> None:
self.recorder = recorder
def info(self, message: str) -> None:
self.recorder.append(message)
#pytest.mark.parametrize(
"chunks, expected_logs", [
([b"foo", b" ", b"bar", b" ", "baz"], ["foo bar baz"]),
]
)
#pytest.mark.parametrize(
"app",
[consume_body_app, consume_partial_body_app]
)
def test_body_logging_middleware_no_errors(chunks: Iterable[bytes], expected_logs: Iterable[str], app: ASGIApp) -> None:
logs: List[str] = []
client = TestClient(BodyLoggingMiddleware(app, TestLogger(logs)))
def chunk_gen() -> Generator[bytes, None, None]:
yield from iter(chunks)
resp = client.get("/", data=chunk_gen())
assert resp.status_code == 200
assert logs == expected_logs
#pytest.mark.parametrize(
"chunks, expected_logs", [
([b"foo", b" ", b"bar", b" ", "baz"], ["foo bar baz"]),
]
)
#pytest.mark.parametrize(
"app",
[consume_body_and_error_app, consume_partial_body_and_error_app]
)
def test_body_logging_middleware_with_errors(chunks: Iterable[bytes], expected_logs: Iterable[str], app: ASGIApp) -> None:
logs: List[str] = []
client = TestClient(BodyLoggingMiddleware(app, TestLogger(logs)))
def chunk_gen() -> Generator[bytes, None, None]:
yield from iter(chunks)
with pytest.raises(TestException):
client.get("/", data=chunk_gen())
assert logs == expected_logs
if __name__ == "__main__":
import os
pytest.main(args=[os.path.abspath(__file__)])
Turns out await request.json() can only be called once per the request cycle. So if you need to access the request body in multiple middlewares for filtering or authentication etc then there's a work around which is to create a custom middleware that copies the contents of request body in request.state. The middleware should be loaded as early as necessary. Each middleware next in chain or controller can then access the request body from request.state instead of calling await request.json() again. Here's a example:
class CopyRequestMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
request_body = await request.json()
request.state.body = request_body
response = await call_next(request)
return response
class LogRequestMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# Since it'll be loaded after CopyRequestMiddleware it can access request.state.body.
request_body = request.state.body
print(request_body)
response = await call_next(request)
return response
The controller will access request body from request.state as well
request_body = request.state.body
Just because such solution not stated yet, but it's worked for me:
from typing import Callable, Awaitable
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import StreamingResponse
from starlette.concurrency import iterate_in_threadpool
class LogStatsMiddleware(BaseHTTPMiddleware):
async def dispatch( # type: ignore
self, request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]],
) -> Response:
response = await call_next(request)
response_body = [section async for section in response.body_iterator]
response.body_iterator = iterate_in_threadpool(iter(response_body))
logging.info(f"response_body={response_body[0].decode()}")
return response
def init_app(app):
app.add_middleware(LogStatsMiddleware)
iterate_in_threadpool actually making from iterator object async Iterator
If you look on implementation of starlette.responses.StreamingResponse you'll see, that this function used exactly for this
If you only want to read request parameters, best solution i found was to implement a "route_class" and add it as arg when creating the fastapi.APIRouter, this is because parsing the request within the middleware is considered problematic
The intention behind the route handler from what i understand is to attach exceptions handling logic to specific routers, but since it's being invoked before every route call, you can use it to access the Request arg
Fastapi documentation
You could do something as follows:
class MyRequestLoggingRoute(APIRoute):
def get_route_handler(self) -> Callable:
original_route_handler = super().get_route_handler()
async def custom_route_handler(request: Request) -> Response:
body = await request.body()
if body:
logger.info(...) # log request with body
else:
logger.info(...) # log request without body
try:
return await original_route_handler(request)
except RequestValidationError as exc:
detail = {"errors": exc.errors(), "body": body.decode()}
raise HTTPException(status_code=422, detail=detail)
return custom_route_handler
The issue is in Uvicorn. The FastAPI/Starlette::Request class does cache the body, but the Uvicorn function RequestResponseCycle::request() does not, so if you instantiate two or more Request classes and ask for the body(), only the instance that asks for the body first will have a valid body.
I solved creating a mock function that returns a cached copy of the request():
class LogRequestsMiddleware:
def __init__(self, app:ASGIApp) -> None:
self.app = app
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
receive_cached_ = await receive()
async def receive_cached():
return receive_cached_
request = Request(scope, receive = receive_cached)
# do what you need here
await self.app(scope, receive_cached, send)
app.add_middleware(LogRequestsMiddleware)
I am writing a helper class for handling multiple urls request in asynchronous way. The code is following.
class urlAsyncClient(object):
def __init__(self, url_arr):
self.url_arr = url_arr
async def async_worker(self):
result = await self.__run()
return result
async def __run(self):
pending_req = []
async with aiohttp.ClientSession() as session:
for url in self.url_arr:
r = self.__fetch(session, url)
pending_req.append(r)
#Awaiting the results altogether instead of one by one
result = await asyncio.wait(pending_req)
return result
#staticmethod
async def __fetch(session, url):
async with session.get(url) as response: #ERROR here
status_code = response.status
if status_code == 200:
return await response.json()
else:
result = await response.text()
print('Error ' + str(response.status_code) + ': ' + result)
return {"error": result}
As awaiting the result one by one seems meaningless in asynchronous. I put them into an array and wait together by await asyncio.wait(pending_req).
But seems like it is not the correct way to do it as I get the following error
in __fetch async with session.get(url) as response: RuntimeError: Session is closed
May I know the correct way to do it? Thanks.
because session has closed before you await it
async with aiohttp.ClientSession() as session:
for url in self.url_arr:
r = self.__fetch(session, url)
pending_req.append(r)
#session closed hear
you can make session an argument to __run, like this
async def async_worker(self):
async with aiohttp.ClientSession() as session:
result = await self.__run(session)
return result
# session will close hear
async def __run(self, session):
pending_req = []
for url in self.url_arr:
r = self.__fetch(session, url)
pending_req.append(r)
#Awaiting the results altogether instead of one by one
result = await asyncio.wait(pending_req)
return result