FastAPI с асинхронным SQLAlchemy, celery и websockets

Начиная с версии 1.4 SQLAlchemy поддерживает asyncio. В этом руководстве мы попытаемся реализовать простой проект с использованием асинхронной функции SQLAlchemy, шифрования, celery и websocket. Но прежде всего, давайте начнем с подключения к базе данных.

Настройка БД с помощью асинхронного SQLAlchemy

Прежде всего, давайте создадим асинхронную сессию:

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from app.core.config import settings


engine = create_async_engine(settings.SQLALCHEMY_DATABASE_URL, echo=True)
SessionLocal = sessionmaker(
    expire_on_commit=False,
    class_=AsyncSession,
    bind=engine,
)

Теперь мы можем внедрить эту сессию в наше представление с помощью зависимостей FastAPI:

async def get_db() -> AsyncSession:
    async with SessionLocal() as session:
        yield session

Все готово, поэтому пришло время использовать DB. В нашем проекте мы будем использовать простую аутентификацию с помощью токенов, поэтому нам нужны две таблицы DB: users и user_tokens

from sqlalchemy.orm import declarative_base
from sqlalchemy_utils import EmailType, force_auto_coercion, PasswordType

Base = declarative_base()
force_auto_coercion()


class User(Base):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(50))
    email = Column(EmailType(50), unique=True, nullable=False)
    password = Column(PasswordType(schemes=["pbkdf2_sha512"]), nullable=False)

    tokens = relationship(
        "UserToken",
        back_populates="user",
        lazy='dynamic',
        cascade="all, delete-orphan",
    )

class UserToken(Base):
    __tablename__ = "user_tokens"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(
        Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
    )
    token = Column(
        UUID(as_uuid=True), unique=True, nullable=False, default=uuid.uuid4
    )
    expires = Column(DateTime)

    user = relationship("User", back_populates="tokens", lazy='joined')Note `force_auto_coercion()`

Обратите внимание, что мы используем force_auto_coercion() перед моделями. Это помогает убедиться, что пароли хэшируются перед сохранением записи в БД.

На этом этапе хорошо бы добавить какой-нибудь инструмент миграции для обработки обновлений метаданных базы данных. Мы будем использовать alembic для этой цели.

Установите alembic:

pip install alembic

И инициализируйте его:

alembic init migrations

Вышеприведенная команда создаст каталог migrations с файлами env.py, README и script.py.mako.

Чтобы заставить alembic работать с нашей базой данных, нам нужно обновить файл env.py:

import asyncio
import os
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncEngine
from alembic import context

config = context.config

if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Here we importing and specifying our DB metadata
from app.db.base import Base  
target_metadata = Base.metadata


# This method returns url of our DB
def get_url():
    return os.getenv("SQLALCHEMY_DATABASE_URL", "")


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well.  By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    # Specify which database we use with alembic
    url = get_url()
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
    context.configure(connection=connection, target_metadata=target_metadata)

    with context.begin_transaction():
        context.run_migrations()


async def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.

    """
    configuration = config.get_section(config.config_ini_section)
    configuration["sqlalchemy.url"] = get_url()
    connectable = AsyncEngine(
        engine_from_config(
            configuration,
            prefix="sqlalchemy.",
            poolclass=pool.NullPool,
            future=True,
        )
    )

    async with connectable.connect() as connection:
        await connection.run_sync(do_run_migrations)

    await connectable.dispose()


if context.is_offline_mode():
    run_migrations_offline()
else:
    asyncio.run(run_migrations_online())

Для создания миграции просто выполните следующую команду:

alembic revision --autogenerate -m "Added required tables"

Чтобы применить миграции и обновить базу данных, выполните следующие действия:

alembic upgrade head

Теперь наша база данных готова, и мы можем попробовать создать нового пользователя и токен:

from sqlalchemy import select
from app.db.base import User

async def get_user_by_email(db: AsyncSession, email: str) -> User:
    statement = select(User).where(User.email == email)
    result = await db.execute(statement)
    return result.scalars().first()

async def create_user(db: AsyncSession, user: UserCreate) -> User:
    db_user = User(
        email=user.email,
        name=user.name,
        password=user.password,
    )
    db.add(db_user)
    await db.commit()
    await db.refresh(db_user)
    return db_user


async def create_user_token(db: AsyncSession, user: User) -> UserToken:
    db_token = UserToken(
        user=user, expires=datetime.now() + timedelta(weeks=2)
    )
    db.add(db_token)
    await db.commit()
    return db_token

И используйте этот код для регистрации нового пользователя:

from fastapi import APIRouter, FastAPI
from pydantic import BaseModel
from app.crud import crud_user

app = FastAPI()

router = APIRouter()

class UserBase(BaseModel):
    email: EmailStr
    name: str

class UserCreate(UserBase):
    password: constr(strip_whitespace=True, min_length=8)

class User(UserBase):
    id: Optional[int] = None
    token: TokenBase | None = None

    class Config:
        orm_mode = True

@router.post("/sign-up/", response_model=User)
async def create_user(user: UserCreate, db: DBSession):
    user_db = await crud_user.get_user_by_email(db, email=user.email)
    if user_db:
        raise HTTPException(status_code=400, detail="User already registered")
    user = await crud_user.create_user(db, user=user)
    user.token = await crud_user.create_user_token(db, user=user)
    return user

app.include_router(user_routes)

Добавить тесты

Ура! Мы реализовали логику регистрации, и было бы неплохо добавить несколько тестов, чтобы проверить, что все работает так, как ожидалось. Поскольку мы используем асинхронное подключение к БД, нам понадобится асинхронный тест. Поэтому давайте добавим несколько специальных приспособлений, которые нам помогут:

import asyncio

import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from app.api.deps import get_db
from app.core.config import settings
from app.db.base import Base
from app.main import app


@pytest.fixture(scope="session")
def event_loop() -> asyncio.AbstractEventLoop:
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="session")
def engine():
    engine = create_async_engine(settings.TEST_SQLALCHEMY_DATABASE_URL)
    yield engine
    engine.sync_engine.dispose()


@pytest_asyncio.fixture(scope="session")
async def prepare_db():
    create_db_engine = create_async_engine(
        settings.POSTGRES_DATABASE_URL,
        isolation_level="AUTOCOMMIT",
    )
    async with create_db_engine.begin() as connection:
        await connection.execute(
            text(
                "drop database if exists {name};".format(
                    name=settings.TEST_DB_NAME
                )
            ),
        )
        await connection.execute(
            text("create database {name};".format(name=settings.TEST_DB_NAME)),
        )


@pytest_asyncio.fixture(scope="session")
async def db_session(engine) -> AsyncSession:
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.drop_all)
        await connection.run_sync(Base.metadata.create_all)
        TestingSessionLocal = sessionmaker(
            expire_on_commit=False,
            class_=AsyncSession,
            bind=engine,
        )
        async with TestingSessionLocal(bind=connection) as session:
            yield session
            await session.flush()
            await session.rollback()


@pytest.fixture(scope="session")
def override_get_db(prepare_db, db_session: AsyncSession):
    async def _override_get_db():
        yield db_session

    return _override_get_db


@pytest_asyncio.fixture(scope="session")
async def async_client(override_get_db):
    app.dependency_overrides[get_db] = override_get_db
    async with AsyncClient(app=app, base_url="http://test") as ac:
        yield ac

Прежде всего, нам нужно изменить область видимости для фикстуры event_loop. По умолчанию это фикстура, скопированная на функцию, но в данном случае мы должны сделать нашу БД также скопированной на функцию, что приведет к проблемам с производительностью, а скоуп сессии исправит это.

Также мы добавили фикстуру движка для использования тестовой базы данных вместо реальной базы данных. В prepare_db мы убеждаемся, что БД создана. В db_session мы создаем вкладки и возвращаем соединение с БД. А в override_get_db мы обновляем зависимости проекта, чтобы убедиться, что представления во время тестирования не будут использовать реальную базу данных. Наконец, мы создали async_client для выполнения асинхронных запросов к нашему API.

Все приготовления сделаны, и вот наши тесты:

import pytest
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.crud.crud_user import create_user
from app.db.base import User
from app.models.users import UserToken
from app.schemas.user import UserCreate

@pytest_asyncio.fixture
async def user(db_session: AsyncSession) -> User:
    user = UserCreate(
      email="nanny_ogg@lancre.com", 
      name="Gytha Ogg", 
      password="12345678"
    )
    user_db = await create_user(db_session, user)
    yield user_db
    await db_session.delete(user_db)
    await db_session.commit()


@pytest.mark.asyncio
async def test_sign_up(async_client, db_session):
    request_data = {
        "email": "sam_vimes@citywatch.com",
        "name": "Sam Vimes",
        "password": "12345678",
    }
    response = await async_client.post("/sign-up/", json=request_data)
    token_counts = await db_session.execute(select(func.count(UserToken.id)))
    assert token_counts.scalar_one() == 1
    assert response.status_code == 200
    assert response.json()["id"] is not None
    assert response.json()["email"] == "sam_vimes@citywatch.com"
    assert response.json()["name"] == "Sam Vimes"
    assert response.json()["token"]["access_token"] is not None
    assert response.json()["token"]["expires"] is not None
    assert response.json()["token"]["token_type"] == "bearer"


@pytest.mark.asyncio
async def test_sign_up_existing_user(async_client, user):
    request_data = {
        "email": user.email,
        "name": "Esme Weatherwax",
        "password": "12345678",
    }
    response = await async_client.post("/sign-up/", json=request_data)
    assert response.status_code == 400
    assert response.json()["detail"] == "User already registered"


@pytest.mark.asyncio
async def test_sign_up_weak_password(async_client):
    request_data = {
        "email": "sam_vimes@citywatch.com",
        "name": "Sam Vimes",
        "password": "123",
    }
    response = await async_client.post("/sign-up/", json=request_data)
    assert response.status_code == 422
    assert (
        response.json()["detail"][0]["msg"]
        == "ensure this value has at least 8 characters"
    )
    assert (
        response.json()["detail"][0]["type"]
        == "value_error.any_str.min_length"
    )

Celery для задач, привязанных к процессору

AsyncIO хорошо подходит для задач, связанных с IO. Именно поэтому мы используем его для чтения данных из базы данных. Но что, если нам нужно выполнить какую-то тяжелую задачу, требующую процессора? В этом случае нам следует подумать о том, чтобы отправить эту задачу в отдельный процесс. И Celery поможет нам в этом.

В нашей системе пользователи смогут создавать посты. Но содержимое постов будет шифроваться перед сохранением в БД. Шифрование – это задача, требующая большого количества процессорных ресурсов, поэтому нам необходимо использовать celery. Давайте создадим необходимые модели:

from sqlalchemy import Column, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship

from app.db.base_class import Base


class UserKeys(Base):
    __tablename__ = "user_keys"

    id = Column(Integer, primary_key=True, index=True)
    user_id = Column(
        Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
    )
    public_key = Column(String(2000), nullable=False)
    is_revoked = Column(Boolean, default=False)

    user = relationship("User", back_populates="keys")


class UserGroup(Base):
    __tablename__ = "user_groups"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String(50))

    users = relationship(
        "User",
        secondary="user_group_association",
        back_populates="groups",
    )
    posts = relationship(
        "Post",
        back_populates="user_group",
        cascade="all, delete-orphan",
    )


class UserGroupAssociation(Base):
    __tablename__ = "user_group_association"

    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, ForeignKey("users.id"))
    group_id = Column(Integer, ForeignKey("user_groups.id"))


class Post(Base):
    __tablename__ = "posts"

    id = Column(Integer, primary_key=True, index=True)
    title = Column(String(100))
    content = Column(Text)
    user_id = Column(
        Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
    )
    group_id = Column(
        Integer,
        ForeignKey("user_groups.id", ondelete='CASCADE'),
        nullable=False,
    )

    author = relationship("User", back_populates="posts")
    user_group = relationship("UserGroup", back_populates="posts")
    keys = relationship(
        "PostKeys",
        back_populates="post",
        cascade="all, delete-orphan",
    )


class PostKeys(Base):
    __tablename__ = "post_keys"

    id = Column(Integer, primary_key=True, index=True)
    post_id = Column(
        Integer,
        ForeignKey("posts.id", ondelete='CASCADE'),
        nullable=False,
    )
    public_key_id = Column(
        Integer, ForeignKey("user_keys.id", ondelete='CASCADE'), nullable=False
    )
    encrypted_key = Column(Text)

    post = relationship("Post", back_populates="keys")
    public_key = relationship("UserKeys")

У каждого пользователя есть своя пара открытый/закрытый ключ. Он загружает открытый ключ на сервер и хранит свой закрытый ключ в секрете. Существуют также группы пользователей. Каждый пользователь может участвовать в различных группах, но каждое сообщение может быть прикреплено только к одной определенной группе. Таким образом, только члены группы смогут прочитать содержимое поста.

При добавлении нового сообщения система генерирует временный ключ, шифрует содержимое сообщения этим ключом и для каждого члена группы шифрует временный ключ открытым ключом пользователя. Когда пользователь получает пост с сервера, он получает зашифрованное содержимое и временный ключ, зашифрованный его открытым ключом. Он может использовать закрытый ключ для расшифровки временного ключа и использовать его для расшифровки содержимого поста. Я знаю, звучит немного безумно, поэтому давайте посмотрим на код.

from pydantic import BaseModel


class PostBase(BaseModel):
    title: str
    content: str
    group_id: int


class PostInDBBase(PostBase):
    id: Optional[int] = None

    class Config:
        orm_mode = True


async def create_post(db: AsyncSession, post: PostBase, author: User) -> Post:
    db_post = Post(
        title=post.title,
        content=post.content,
        group_id=post.group_id,
        author=author,
    )
    db.add(db_post)
    await db.commit()
    await db.refresh(db_post)
    return db_post


@router.post("/posts/", response_model=PostInDBBase, status_code=201)
async def create_post(
    post: PostBase,
    db: DBSession,
    current_user: CurrentUser,
):
    plain_content = post.content
    post.content = ""
    post = await create_post(
        db=db,
        post=post,
        author=current_user,
    )
    encrypt_post_content.delay(post_id=post.id, content=plain_content)
    return post

Это представление, которое получает пост и сохраняет его в БД. Наиболее интересной частью здесь является метод encrypt_post_content.delay(). Это фактически задача celery, которая будет выполняться в отдельном процессе. Вот оно:

import os

from celery import Celery
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from sqlalchemy import create_engine, select, update
from sqlalchemy.orm import sessionmaker

from app.core.crypto_tools import (
    asymmetric_encryption,
    generate_symmetric_key,
    symmetric_encryption,
)
from app.db.base import Post, PostKeys, User, UserGroup, UserKeys
from app.core.config import settings

celery = Celery("secureblogs")
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL")


sync_engine = create_engine(settings.SYNC_SQLALCHEMY_DATABASE_URL, echo=True)
SyncSessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=sync_engine,
)


@celery.task(name="encrypt_post_content")
def encrypt_post_content(post_id: int, content: str):
    # generate temp key and encrypt content
    key = generate_symmetric_key()
    encrypted_content = symmetric_encryption(content, key)
    with SyncSessionLocal() as db:
        # update post instance
        post_statement = (
            update(Post)
            .returning(Post.group_id)
            .where(Post.id == post_id)
            .values(content=encrypted_content)
        )
        post = db.execute(post_statement).fetchone()

        # fetch user's public keys from DB
        users_subquery = (
            select(User.id)
            .where(User.groups.any(UserGroup.id.in_([post.group_id])))
            .subquery()
        )
        statement = select(UserKeys).where(
            (UserKeys.user_id.in_(users_subquery))
            & (UserKeys.is_revoked == False)
        )
        public_keys = db.execute(statement).scalars().all()
        db_post_keys = []
        for public_key in public_keys:
            # Save generated keys in DB
            public_pem_data = public_key.public_key
            public_key_object = load_pem_public_key(public_pem_data.encode())
            encrypted_key = asymmetric_encryption(key, public_key_object)
            db_post_keys.append(
                PostKeys(
                    post_id=post_id,
                    public_key_id=public_key.id,
                    encrypted_key=encrypted_key,
                )
            )
        db.bulk_save_objects(db_post_keys)
        db.commit()

Задача Celery – это просто функция sync python, поэтому для выполнения запросов к БД мы используем сессию sync базы данных внутри нее.

Websockets

Фух, похоже, мы сделали это! Наш API позволяет создавать зашифрованные посты. Но подождите, это работает только для уже существующих пользователей. Что если новый пользователь присоединится к группе и захочет прочитать какой-нибудь пост? Он не сможет этого сделать, потому что не сможет расшифровать временный ключ. Ему нужно, чтобы тот, кто создал пост, прислал ему временный ключ. И здесь очень пригодятся вебсокеты. Когда пользователь запрашивает доступ к посту, мы отправляем автору поста уведомление в реальном времени. Автор поста получает уведомление и решает, одобрить запрос или отклонить его. Итак, давайте попробуем это реализовать. Прежде всего, добавьте новую модель. Она содержит информацию о пользователе, который запрашивает доступ к посту, сам пост и открытый ключ пользователя:

class ReadPostRequest(Base):
    __tablename__ = "read_post_request"

    id = Column(Integer, primary_key=True, index=True)
    post_id = Column(
        Integer,
        ForeignKey("posts.id", ondelete='CASCADE'),
        nullable=False,
    )
    user_id = Column(
        Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
    )
    public_key_id = Column(
        Integer, ForeignKey("user_keys.id", ondelete='CASCADE'), nullable=False
    )

    post = relationship("Post")
    requester = relationship("User")
    public_key = relationship("UserKeys")

Теперь нам нужна конечная точка, которая позволяет создать новый запрос:

async def get_post(db: AsyncSession, post_id: int) -> Post:
    statement = select(Post).where(Post.id == post_id)
    result = await db.execute(statement)
    return result.scalars().first()

async def get_user_key(
    db: AsyncSession,
    user: User,
) -> UserKeys:
    statement = select(UserKeys).where(
        (UserKeys.user == user) & (UserKeys.is_revoked == False)
    )
    result = await db.execute(statement)
    return result.scalars().first()

async def add_read_post_request(
    db: AsyncSession, user: User, post_id: int
) -> ReadPostRequest:
    exists_statement = select(ReadPostRequest.id).where(
        (ReadPostRequest.user_id == user.id)
        & (ReadPostRequest.post_id == post_id)
    )
    result = await db.execute(exists_statement)
    if result.scalars().first():
        return None
    public_key_statement = select(UserKeys).where(
        (UserKeys.is_revoked == False) & (UserKeys.user_id == user.id)
    )
    result = await db.execute(public_key_statement)
    if not (public_key := result.scalars().first()):
        return None
    db_read_post_request = ReadPostRequest(
        user_id=user.id,
        post_id=post_id,
        public_key=public_key,
    )
    db.add(db_read_post_request)
    await db.commit()
    await db.refresh(db_read_post_request)
    return db_read_post_request

@router.post("/posts/{post_id}/request_read/", status_code=204)
async def add_read_post_request(
    post_id: int,
    db: DBSession,
    current_user: CurrentUser,
):
    post = await crud_post.get_post(db, post_id)
    if not post:
        raise HTTPException(status_code=404)
    user_key = await crud_user.get_user_key(db, current_user)
    request = await crud_post.add_read_post_request(db, current_user, post_id)
    if request:
        await ws_manager.send_personal_message(
            {
                'request_id': request.id,
                'post_id': post_id,
                'requested_user_id': current_user.id,
                'user_public_key': user_key.public_key,
            },
            post.user_id,
        )

Здесь нет ничего нового, кроме последней строки. Мы просто проверяем, действительно ли почта существует в БД. Затем мы получаем открытый ключ пользователя, создаем новый запрос и, наконец, отправляем уведомление по websocket. Давайте посмотрим поближе, как мы это делаем.

from fastapi import WebSocket


class ConnectionManager:
    def __init__(self):
        self.active_connections: dict[int, WebSocket] = {}

    async def connect(self, user_id: int, websocket: WebSocket):
        await websocket.accept()
        self.active_connections[user_id] = websocket

    def disconnect(self, user_id: int):
        self.active_connections.pop(user_id)

    async def send_personal_message(self, message: dict, user_id: int):
        if websocket := self.active_connections.get(user_id):
            await websocket.send_json(message)



ws_manager = ConnectionManager()

Это наш менеджер websocket. Здесь у нас есть словарь, в котором хранятся идентификаторы пользователей и соединения websocket, связанные с каждым из них. Когда кто-то хочет отправить личное уведомление, он использует send_personal_message .

Наконец, давайте проверим, как создать новое соединение websocket.

from typing import Annotated

from fastapi import (
    APIRouter,
    Depends,
    Query,
    status,
    WebSocket,
    WebSocketDisconnect,
    WebSocketException,
)

from app.api.deps import DBSession
from app.api.websockets.managers import ws_manager
from app.crud.crud_user import get_user_by_token


router = APIRouter()


async def get_token(
    websocket: WebSocket,
    token: Annotated[str | None, Query()] = None,
):
    if token is None:
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
    return token


@router.websocket("/ws/post_request")
async def websocket_endpoint(
    websocket: WebSocket,
    db: DBSession,
    token: Annotated[str, Depends(get_token)],
):
    user = await get_user_by_token(db, token)
    if not user:
        raise WebSocketException(code=status.HTTP_401_UNAUTHORIZED)
    try:
        await ws_manager.connect(user.id, websocket)
        await ws_manager.send_personal_message(
            {"message": "connection accepted"},
            user.id,
        )
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        ws_manager.disconnect(user.id)

Мы создали новую конечную точку websocket /ws/post_request . Эта конечная точка проверяет токен пользователя, и если токен действителен, она создает новое соединение и отправляет пользователю подтверждающее сообщение о том, что соединение принято.

Заключение

Ну, вот и все. Надеюсь, этот пост был полезен.

+1
0
+1
1
+1
0
+1
0
+1
0

Ответить

Ваш адрес email не будет опубликован. Обязательные поля помечены *