DEV Community

whchi
whchi

Posted on

Testing FastAPI with async database session

Get started

FastAPI uses Python's asyncio module to improve its I/O performance.

According to the official documentation, path operation functions and dependencies (Depends) are always handled asynchronously by FastAPI, regardless of whether you declare them with async def (run as coroutines on the event loop) or def (run in the thread pool).

When you declare your function with async def, you MUST use the await keyword on the asynchronous calls inside it to avoid "sequential" (blocking) behavior.

This behavior is slightly different from JavaScript's async-await, which could be the subject of another significant discussion.

The code

Suppose your application looks like this

# Application-wide async engine; the connection URL is read from the db settings.
engine = create_async_engine(
    echo=True,
    url=get_db_settings().async_connection_string,
)

# Session factory bound to the engine above; it produces AsyncSession objects.
async_session_global = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)


async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` for use as a FastAPI dependency.

    The session is obtained from ``async_session_global.begin()``. If the
    consumer raises, the session is rolled back and the exception is
    re-raised; the session is closed in every case.
    """
    async with async_session_global.begin() as session:
        try:
            yield session
        except:
            # Undo whatever the consumer did, then propagate the error unchanged.
            await session.rollback()
            raise
        finally:
            # NOTE(review): this close runs before the ``begin()`` context exits;
            # confirm pending work is still committed as intended.
            await session.close()

# define the FastAPI application
app = FastAPI()

router = APIRouter()
@router.get('/api/async-examples/{id}')
async def get_example(id: int, db = Depends(get_async_session)):
    """Return every ``Example`` row.

    Declared with ``async def`` because the body awaits the session; ``await``
    is a syntax error inside a plain ``def``.

    Args:
        id: path parameter; accepted from the URL but not used to filter the query.
        db: async session injected through ``get_async_session``.

    Returns:
        The list produced by ``Result.all()`` for ``select(Example)``.
    """
    # Await the execute() coroutine first, then read rows from the Result;
    # ``await db.execute(...).all()`` would call .all() on the coroutine itself.
    result = await db.execute(select(Example))
    return result.all()

@router.put('/api/async-examples/{id}')
async def put_example(id: int, db = Depends(get_async_session)):
    """Overwrite ``name`` and ``age`` of one ``Example`` row and return it.

    Declared with ``async def`` because the body awaits the session; ``await``
    is a syntax error inside a plain ``def``.

    Args:
        id: primary key of the row to update, taken from the URL path.
        db: async session injected through ``get_async_session``.

    Returns:
        The updated ``Example`` as returned by ``scalar_one()``.

    Raises:
        sqlalchemy.exc.NoResultFound: if no row has the given id
            (``MultipleResultsFound`` if more than one does).
    """
    # where() takes SQL expressions, not keyword arguments.
    await db.execute(
        update(Example).where(Example.id == id).values(name='testtest', age=123)
    )
    await db.commit()
    # No refresh() here: it needs a mapped instance, not the Example class;
    # the select below reads the committed row instead.
    result = await db.execute(select(Example).filter_by(id=id))
    return result.scalar_one()

app.include_router(router)
Enter fullscreen mode Exit fullscreen mode

Firstly, we need fixtures for our tests. Here, I'll be using asyncpg as my async database connector.

# conftest.py
import asyncio
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from fastapi import FastAPI
import pytest

# Async engine used by the test fixtures in this conftest.
engine = create_async_engine(
    echo=True,
    url='postgresql+asyncpg://...',
)

# create all tables once per test session and drop them when the session ends
@pytest.fixture(scope='session')
async def async_db_engine():
    """Create the schema, yield the engine, then drop the schema.

    Uses the module-level ``engine`` defined in this conftest; no
    ``async_engine`` name exists here.

    Yields:
        The async engine, with every table in ``SQLModel.metadata`` created.
    """
    async with engine.begin() as conn:
        # create_all is a synchronous API, so it is run through run_sync.
        await conn.run_sync(SQLModel.metadata.create_all)

    yield engine

    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.drop_all)

# give each test its own session and truncate all tables afterwards to isolate tests
@pytest.fixture(scope='function')
async def async_db(async_db_engine):
    """Yield a per-test ``AsyncSession`` and wipe every table afterwards.

    Args:
        async_db_engine: the session-scoped engine fixture the session binds to.

    Yields:
        An ``AsyncSession`` on which a transaction has already been begun.
    """
    # Local import so this fixture brings its own dependency into scope.
    from sqlalchemy import text

    async_session = sessionmaker(
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
        bind=async_db_engine,
        class_=AsyncSession,
    )

    async with async_session() as session:
        await session.begin()

        yield session

        # Discard whatever the test left uncommitted.
        await session.rollback()

        # Committed rows survive the rollback, so truncate in reverse
        # dependency order. Raw SQL must be wrapped in text() for execute().
        for table in reversed(SQLModel.metadata.sorted_tables):
            await session.execute(text(f'TRUNCATE {table.name} CASCADE;'))
        # One commit covers all the truncates; no need to commit per table.
        await session.commit()

@pytest.fixture(scope='session')
async def async_client() -> AsyncClient:
    """Return a session-scoped httpx ``AsyncClient`` bound to a FastAPI app.

    NOTE(review): the app passed here is a freshly constructed ``FastAPI()``
    with no router included in this fixture. Presumably the project's real
    application object should be used so the tested routes exist; confirm.
    NOTE(review): the client is returned without being closed; confirm that
    is intended.
    """
    bare_app = FastAPI()
    return AsyncClient(app=bare_app, base_url='http://localhost')

# provide one event loop for the whole test session
@pytest.fixture(scope='session')
def event_loop():
    """Create a new event loop from the current policy, yield it, then close it."""
    session_loop = asyncio.get_event_loop_policy().new_event_loop()
    yield session_loop
    session_loop.close()

# assumes an `Example` model exists
@pytest.fixture
async def async_example_orm(async_db: AsyncSession) -> Example:
    """Insert one ``Example`` row through ``async_db`` and return it refreshed."""
    seeded = Example(name='test', age=18, nick_name='my_nick')
    async_db.add(seeded)
    await async_db.commit()
    # Re-read the row after the commit before handing it to the test.
    await async_db.refresh(seeded)
    return seeded


Enter fullscreen mode Exit fullscreen mode

Then, write our tests

# test_what_ever_you_want.py
# make all test mark with `asyncio`
pytestmark = pytest.mark.asyncio

async def test_get_example(async_client: AsyncClient, async_db: AsyncSession,
                           async_example_orm: Example) -> None:
    """GET the seeded example and check it can still be selected by id."""
    url = f'/api/async-examples/{async_example_orm.id}'
    response = await async_client.get(url)

    assert response.status_code == status.HTTP_200_OK

    # Look the row up through the test session and compare ids.
    stmt = select(Example).filter_by(id=async_example_orm.id)
    result = await async_db.execute(stmt)
    assert result.scalar_one().id == async_example_orm.id

async def test_update_example(async_client: AsyncClient, async_db: AsyncSession,
                              async_example_orm: Example) -> None:
    """PUT new values and check the stored name matches ``data.name`` in the response."""
    payload = {'name': 'updated_name', 'age': 20}
    url = f'/api/async-examples/{async_example_orm.id}'

    response = await async_client.put(url, json=payload)
    assert response.status_code == status.HTTP_200_OK

    # Reload the fixture object, then compare the stored row with the response body.
    await async_db.refresh(async_example_orm)
    stmt = select(Example).filter_by(id=async_example_orm.id)
    result = await async_db.execute(stmt)
    assert result.scalar_one().name == response.json()['data']['name']
Enter fullscreen mode Exit fullscreen mode

The key here is the async_db and event_loop fixtures; you also have to make sure your program's db session does not use a global commit.

Top comments (10)

Collapse
 
scholli profile image
Tom Stein

Nice idea! How do you override the get_async_session ("get_db") dependency used by FastAPI endpoints to query data with the new one from the tests? I don't see you using app.dependency_overrides anywhere.

Collapse
 
whchi profile image
whchi

I didn't override my database connection, I use a test database for testing.

Collapse
 
scholli profile image
Tom Stein

Ahh okay, I see. That is also a viable option. In my case, I want to use a local SQLite in-memory DB for testing. I achieved this using the following code:

@pytest_asyncio.fixture(scope="function")
async def async_client(async_db_session: AsyncSession) -> AsyncClient:
    """Build an httpx client whose requests use the test database session.

    Args:
        async_db_session: the per-test session handed to the app in place of ``get_db``.

    Returns:
        AsyncClient: the async httpx test client bound to ``app``.
    """

    def _provide_test_session() -> Iterator[AsyncSession]:
        """Yield the fixture's session once, in dependency-generator form.

        Yields:
            Iterator[AsyncSession]: An iterator containing one database session.
        """
        yield async_db_session

    # Register the override before creating the client, as the endpoints resolve get_db via the app.
    app.dependency_overrides[get_db] = _provide_test_session
    test_client = AsyncClient(app=app, base_url="http://test-server")
    return test_client
Enter fullscreen mode Exit fullscreen mode
Thread Thread
 
whchi profile image
whchi

oh, I know what you mean.

I separate my environments to achieve what you did, using the pytest-env package.

  • pyproject.toml
[tool.pytest.ini_options]
env=[
  "APP_ENV=test",
  "DB_NAME=test_db",
  "DB_HOST=localhost",
  "DB_USER=test_root",
  "DB_PASSWORD=test_pwd"
]
Enter fullscreen mode Exit fullscreen mode

And my db connection is wrapped in another module

  • db_connection.py
from sqlalchemy.ext.asyncio import create_async_engine

engine = create_async_engine(
    url=f'postgresql+asyncpg://{env.DB_USER}:{env.DB_PASSWORD}@{env.DB_HOST}:5432/{env.DB_NAME}',
    echo=True,
)

Enter fullscreen mode Exit fullscreen mode

So I didn't do app.dependency_overrides[get_db] to connect my test database.

Collapse
 
andreybzz profile image
Andrey Babichev

And a second one: any action deep inside after await session.begin() may call .commit, which breaks this approach.

Collapse
 
whchi profile image
whchi • Edited

Yes, but it's hidden in the database(or coroutine). I prefer explicitly calling 'commit' to let developers know what is happening.

Collapse
 
andreybzz profile image
Andrey Babichev • Edited

Let me suggest you another solution with a wrapped transaction.

@pytest_asyncio.fixture(scope="session")
async def async_db_connection() -> AsyncGenerator[AsyncConnection, None]:
    """Create the schema and yield one connection for the whole test session.

    After the session, the connection is rolled back and closed, the schema is
    dropped and the engine is disposed.

    Yields:
        AsyncConnection: a connection on an engine whose tables exist.
    """
    async_engine = create_async_engine(
        ASYNC_DATABASE_URL, echo=False, connect_args={"timeout": 0.5}
    )

    async with async_engine.begin() as conn:
        # separate connection because .create_all makes .commit inside
        await conn.run_sync(SQLModel.metadata.create_all)

    conn = await async_engine.connect()
    try:
        yield conn
    finally:
        await conn.rollback()
        # Release the connection explicitly; connect() without a context
        # manager does not close it for us.
        await conn.close()

    async with async_engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.drop_all)

    await async_engine.dispose()


async def __session_within_transaction(
    async_db_connection: AsyncConnection,
) -> AsyncGenerator[AsyncSession, None]:
    """Yield a task-scoped session bound to a connection with an open transaction.

    The transaction is begun on ``async_db_connection`` before the yield and
    rolled back after it.
    """
    session_factory = sessionmaker(
        bind=async_db_connection,
        class_=AsyncSession,
        autoflush=False,
        autocommit=False,
        expire_on_commit=False,
    )
    outer_tx = await async_db_connection.begin()
    scoped = async_scoped_session(session_factory, scopefunc=current_task)

    yield scoped

    # no truncate needed: rolling back the transaction undoes all the data
    await outer_tx.rollback()


@pytest_asyncio.fixture(scope="function")
async def async_db_session(
    async_db_connection: AsyncConnection,
) -> AsyncGenerator[AsyncSession, None]:
    """Per-test fixture that yields the session from ``__session_within_transaction``."""
    transactional_sessions = __session_within_transaction(async_db_connection)
    async for scoped_session in transactional_sessions:
        # setup some data per function
        yield scoped_session
Enter fullscreen mode Exit fullscreen mode
Thread Thread
 
whchi profile image
whchi • Edited

I didn't know there was an async_scoped_session. It looks much better and less redundant. Thank you!

Collapse
 
andreybzz profile image
Andrey Babichev

What's the reason to use truncate? Shouldn't the rollback prevent the data from being saved?

Collapse
 
whchi profile image
whchi

In my code base there are many implicit/explicit commits, so some data may not be rolled back correctly; that's why I add an extra truncate step.