I have this function that is used as a dependency in all of my API endpoints:

from fastapi import HTTPException, Request, status


async def get_user_id_or_401(request: Request) -> str:
    user_id: str = request.headers.get("x-cognito-user-id")
    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
    return user_id
How do I unit test this function? My first idea was to patch request.headers.get, but I don't know how to go about that.
So my question is: how do I patch request.headers.get? And if there is a better way to test this, please tell me.
Answer:
That's what dependency overrides are for: https://fastapi.tiangolo.com/advanced/testing-dependencies/
Something like this:
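from fastapi.testclient import TestClient

# Assumes get_user_id_or_401 is defined in (or importable from) main.
from main import app, get_user_id_or_401

client = TestClient(app)


# The override takes no parameters: FastAPI resolves the override's own
# signature, and an unannotated `request` argument would be treated as a
# required query parameter (the call would then fail with 422).
async def mock_user_id() -> str:
    return "foo"


app.dependency_overrides[get_user_id_or_401] = mock_user_id


def test_request():
    response = client.get("/")
    assert response.status_code == 200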
But I think it is better to send the header in the test request itself, so that no override is needed. That also exercises get_user_id_or_401, which the override bypasses and would otherwise leave untested.
def test_request():
    response = client.get("/", headers={"x-cognito-user-id": "foo"})
    assert response.status_code == 200
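If you want to unit test the user ID function directly, you can import it into your test file and construct Request objects yourself from a minimal ASGI scope.
Note that in the scope, header names and values must be bytes, so headers has the type list[tuple[bytes, bytes]], and header names should be lowercase, since Starlette lowercases the key you look up before comparing.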
import asyncio

from fastapi import HTTPException, Request, status


async def get_user_id_or_401(request: Request) -> str:
    user_id = request.headers.get("x-cognito-user-id")
    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
    return user_id


def test_uid():
    # Minimal ASGI HTTP scope: header names and values are lowercase bytes.
    r = Request(scope={
        "type": "http",
        "headers": [(b"x-cognito-user-id", b"foo")],
    })
    # The dependency is async, so run the coroutine to get its return value.
    assert asyncio.run(get_user_id_or_401(r)) == "foo"