How to patch a FastAPI Request

Time:08-09

I have this function that is called as a dependency in all of my APIs.

from fastapi import HTTPException, Request, status


async def get_user_id_or_401(request: Request) -> str:
    user_id: str = request.headers.get("x-cognito-user-id")

    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
    return user_id

How do I unit test this function? My first thought was to patch request.headers.get, but I don't know how to go about that.

So my question is: how do I patch request.headers.get? If there is a better way to test this, please tell me.

Answer:

That's what dependency overrides are for: https://fastapi.tiangolo.com/advanced/testing-dependencies/

Something like this:

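from fastapi.testclient import TestClient
# Assumes get_user_id_or_401 is importable from main alongside app.
from main import app, get_user_id_or_401


client = TestClient(app)

# The override takes no parameters: FastAPI would treat an unannotated
# `request` argument as a required query parameter and respond with 422.
async def mock_user_id():
    return "foo"

app.dependency_overrides[get_user_id_or_401] = mock_user_id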

def test_request():
    response = client.get("/")
    assert response.status_code == 200

However, I think it is better to send the header in the test requests themselves, so that no override is needed. That also exercises get_user_id_or_401, which the override would otherwise leave uncovered.

def test_request():
    response = client.get("/", headers={"x-cognito-user-id": "foo"})
    assert response.status_code == 200

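If you want to unit test the user ID function directly, you can import it into your test file and pass it a hand-built Request.

Note that the header names and values must be bytes, so the type of headers in the scope is list[tuple[bytes, bytes]]. The version below is a plain def so the test can call it directly; the async version from the question would have to be awaited.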

from starlette import status
from starlette.exceptions import HTTPException
from starlette.requests import Request


def get_user_id_or_401(request: Request) -> str:
    user_id: str = request.headers.get("x-cognito-user-id")

    if not user_id:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Unauthorized"
        )
    return user_id


def test_uid():
    r = Request(scope={
        "type": "http",
        "headers": [("x-cognito-user-id".encode(), "foo".encode())],
    })
    assert get_user_id_or_401(r) == "foo"