FastAPI с асинхронным SQLAlchemy, celery и websockets
Начиная с версии 1.4 SQLAlchemy поддерживает asyncio. В этом руководстве мы попытаемся реализовать простой проект с использованием асинхронной функции SQLAlchemy, шифрования, celery и websocket. Но прежде всего, давайте начнем с подключения к базе данных.
Настройка БД с помощью асинхронного SQLAlchemy
Прежде всего, давайте создадим асинхронную сессию:
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.core.config import settings
engine = create_async_engine(settings.SQLALCHEMY_DATABASE_URL, echo=True)
SessionLocal = sessionmaker(
expire_on_commit=False,
class_=AsyncSession,
bind=engine,
)
Теперь мы можем внедрить эту сессию в наше представление с помощью зависимостей FastAPI:
async def get_db() -> AsyncSession:
async with SessionLocal() as session:
yield session
Все готово, поэтому пришло время использовать DB. В нашем проекте мы будем использовать простую аутентификацию с помощью токенов, поэтому нам нужны две таблицы DB: users и user_tokens
from sqlalchemy.orm import declarative_base
from sqlalchemy_utils import EmailType, force_auto_coercion, PasswordType
Base = declarative_base()
force_auto_coercion()
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(50))
email = Column(EmailType(50), unique=True, nullable=False)
password = Column(PasswordType(schemes=["pbkdf2_sha512"]), nullable=False)
tokens = relationship(
"UserToken",
back_populates="user",
lazy='dynamic',
cascade="all, delete-orphan",
)
class UserToken(Base):
__tablename__ = "user_tokens"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(
Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
)
token = Column(
UUID(as_uuid=True), unique=True, nullable=False, default=uuid.uuid4
)
expires = Column(DateTime)
user = relationship("User", back_populates="tokens", lazy='joined')Note `force_auto_coercion()`
Обратите внимание, что мы используем force_auto_coercion() перед моделями. Это помогает убедиться, что пароли хэшируются перед сохранением записи в БД.
На этом этапе хорошо бы добавить какой-нибудь инструмент миграции для обработки обновлений метаданных базы данных. Мы будем использовать alembic для этой цели.
Установите alembic:
pip install alembic
И инициализируйте его:
alembic init migrations
Вышеприведенная команда создаст каталог migrations с файлами env.py, README и script.py.mako.
Чтобы заставить alembic работать с нашей базой данных, нам нужно обновить файл env.py:
import asyncio
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import AsyncEngine
from alembic import context
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Here we importing and specifying our DB metadata
from app.db.base import Base
target_metadata = Base.metadata
# This method returns url of our DB
def get_url():
return os.getenv("SQLALCHEMY_DATABASE_URL", "")
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
# Specify which database we use with alembic
url = get_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
configuration = config.get_section(config.config_ini_section)
configuration["sqlalchemy.url"] = get_url()
connectable = AsyncEngine(
engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
future=True,
)
)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
if context.is_offline_mode():
run_migrations_offline()
else:
asyncio.run(run_migrations_online())
Для создания миграции просто выполните следующую команду:
alembic revision --autogenerate -m "Added required tables"
Чтобы применить миграции и обновить базу данных, выполните следующие действия:
alembic upgrade head
Теперь наша база данных готова, и мы можем попробовать создать нового пользователя и токен:
from sqlalchemy import select
from app.db.base import User
async def get_user_by_email(db: AsyncSession, email: str) -> User:
statement = select(User).where(User.email == email)
result = await db.execute(statement)
return result.scalars().first()
async def create_user(db: AsyncSession, user: UserCreate) -> User:
db_user = User(
email=user.email,
name=user.name,
password=user.password,
)
db.add(db_user)
await db.commit()
await db.refresh(db_user)
return db_user
async def create_user_token(db: AsyncSession, user: User) -> UserToken:
db_token = UserToken(
user=user, expires=datetime.now() + timedelta(weeks=2)
)
db.add(db_token)
await db.commit()
return db_token
И используйте этот код для регистрации нового пользователя:
from fastapi import APIRouter, FastAPI
from pydantic import BaseModel
from app.crud import crud_user
app = FastAPI()
router = APIRouter()
class UserBase(BaseModel):
email: EmailStr
name: str
class UserCreate(UserBase):
password: constr(strip_whitespace=True, min_length=8)
class User(UserBase):
id: Optional[int] = None
token: TokenBase | None = None
class Config:
orm_mode = True
@router.post("/sign-up/", response_model=User)
async def create_user(user: UserCreate, db: DBSession):
user_db = await crud_user.get_user_by_email(db, email=user.email)
if user_db:
raise HTTPException(status_code=400, detail="User already registered")
user = await crud_user.create_user(db, user=user)
user.token = await crud_user.create_user_token(db, user=user)
return user
app.include_router(user_routes)
Добавить тесты
Ура! Мы реализовали логику регистрации, и было бы неплохо добавить несколько тестов, чтобы проверить, что все работает так, как ожидалось. Поскольку мы используем асинхронное подключение к БД, нам понадобится асинхронный тест. Поэтому давайте добавим несколько специальных приспособлений, которые нам помогут:
import asyncio
import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.api.deps import get_db
from app.core.config import settings
from app.db.base import Base
from app.main import app
@pytest.fixture(scope="session")
def event_loop() -> asyncio.AbstractEventLoop:
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
def engine():
engine = create_async_engine(settings.TEST_SQLALCHEMY_DATABASE_URL)
yield engine
engine.sync_engine.dispose()
@pytest_asyncio.fixture(scope="session")
async def prepare_db():
create_db_engine = create_async_engine(
settings.POSTGRES_DATABASE_URL,
isolation_level="AUTOCOMMIT",
)
async with create_db_engine.begin() as connection:
await connection.execute(
text(
"drop database if exists {name};".format(
name=settings.TEST_DB_NAME
)
),
)
await connection.execute(
text("create database {name};".format(name=settings.TEST_DB_NAME)),
)
@pytest_asyncio.fixture(scope="session")
async def db_session(engine) -> AsyncSession:
async with engine.begin() as connection:
await connection.run_sync(Base.metadata.drop_all)
await connection.run_sync(Base.metadata.create_all)
TestingSessionLocal = sessionmaker(
expire_on_commit=False,
class_=AsyncSession,
bind=engine,
)
async with TestingSessionLocal(bind=connection) as session:
yield session
await session.flush()
await session.rollback()
@pytest.fixture(scope="session")
def override_get_db(prepare_db, db_session: AsyncSession):
async def _override_get_db():
yield db_session
return _override_get_db
@pytest_asyncio.fixture(scope="session")
async def async_client(override_get_db):
app.dependency_overrides[get_db] = override_get_db
async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac
Прежде всего, нам нужно изменить область видимости для фикстуры event_loop. По умолчанию это фикстура, скопированная на функцию, но в данном случае мы должны сделать нашу БД также скопированной на функцию, что приведет к проблемам с производительностью, а скоуп сессии исправит это.
Также мы добавили фикстуру движка для использования тестовой базы данных вместо реальной базы данных. В prepare_db мы убеждаемся, что БД создана. В db_session мы создаем вкладки и возвращаем соединение с БД. А в override_get_db мы обновляем зависимости проекта, чтобы убедиться, что представления во время тестирования не будут использовать реальную базу данных. Наконец, мы создали async_client для выполнения асинхронных запросов к нашему API.
Все приготовления сделаны, и вот наши тесты:
import pytest
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.crud.crud_user import create_user
from app.db.base import User
from app.models.users import UserToken
from app.schemas.user import UserCreate
@pytest_asyncio.fixture
async def user(db_session: AsyncSession) -> User:
user = UserCreate(
email="nanny_ogg@lancre.com",
name="Gytha Ogg",
password="12345678"
)
user_db = await create_user(db_session, user)
yield user_db
await db_session.delete(user_db)
await db_session.commit()
@pytest.mark.asyncio
async def test_sign_up(async_client, db_session):
request_data = {
"email": "sam_vimes@citywatch.com",
"name": "Sam Vimes",
"password": "12345678",
}
response = await async_client.post("/sign-up/", json=request_data)
token_counts = await db_session.execute(select(func.count(UserToken.id)))
assert token_counts.scalar_one() == 1
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["email"] == "sam_vimes@citywatch.com"
assert response.json()["name"] == "Sam Vimes"
assert response.json()["token"]["access_token"] is not None
assert response.json()["token"]["expires"] is not None
assert response.json()["token"]["token_type"] == "bearer"
@pytest.mark.asyncio
async def test_sign_up_existing_user(async_client, user):
request_data = {
"email": user.email,
"name": "Esme Weatherwax",
"password": "12345678",
}
response = await async_client.post("/sign-up/", json=request_data)
assert response.status_code == 400
assert response.json()["detail"] == "User already registered"
@pytest.mark.asyncio
async def test_sign_up_weak_password(async_client):
request_data = {
"email": "sam_vimes@citywatch.com",
"name": "Sam Vimes",
"password": "123",
}
response = await async_client.post("/sign-up/", json=request_data)
assert response.status_code == 422
assert (
response.json()["detail"][0]["msg"]
== "ensure this value has at least 8 characters"
)
assert (
response.json()["detail"][0]["type"]
== "value_error.any_str.min_length"
)
Celery для задач, привязанных к процессору
AsyncIO хорошо подходит для задач, связанных с IO. Именно поэтому мы используем его для чтения данных из базы данных. Но что, если нам нужно выполнить какую-то тяжелую задачу, требующую процессора? В этом случае нам следует подумать о том, чтобы отправить эту задачу в отдельный процесс. И Celery поможет нам в этом.
В нашей системе пользователи смогут создавать посты. Но содержимое постов будет шифроваться перед сохранением в БД. Шифрование – это задача, требующая большого количества процессорных ресурсов, поэтому нам необходимо использовать celery. Давайте создадим необходимые модели:
from sqlalchemy import Column, ForeignKey, Integer, String, Text
from sqlalchemy.orm import relationship
from app.db.base_class import Base
class UserKeys(Base):
__tablename__ = "user_keys"
id = Column(Integer, primary_key=True, index=True)
user_id = Column(
Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
)
public_key = Column(String(2000), nullable=False)
is_revoked = Column(Boolean, default=False)
user = relationship("User", back_populates="keys")
class UserGroup(Base):
__tablename__ = "user_groups"
id = Column(Integer, primary_key=True, index=True)
name = Column(String(50))
users = relationship(
"User",
secondary="user_group_association",
back_populates="groups",
)
posts = relationship(
"Post",
back_populates="user_group",
cascade="all, delete-orphan",
)
class UserGroupAssociation(Base):
__tablename__ = "user_group_association"
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("users.id"))
group_id = Column(Integer, ForeignKey("user_groups.id"))
class Post(Base):
__tablename__ = "posts"
id = Column(Integer, primary_key=True, index=True)
title = Column(String(100))
content = Column(Text)
user_id = Column(
Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
)
group_id = Column(
Integer,
ForeignKey("user_groups.id", ondelete='CASCADE'),
nullable=False,
)
author = relationship("User", back_populates="posts")
user_group = relationship("UserGroup", back_populates="posts")
keys = relationship(
"PostKeys",
back_populates="post",
cascade="all, delete-orphan",
)
class PostKeys(Base):
__tablename__ = "post_keys"
id = Column(Integer, primary_key=True, index=True)
post_id = Column(
Integer,
ForeignKey("posts.id", ondelete='CASCADE'),
nullable=False,
)
public_key_id = Column(
Integer, ForeignKey("user_keys.id", ondelete='CASCADE'), nullable=False
)
encrypted_key = Column(Text)
post = relationship("Post", back_populates="keys")
public_key = relationship("UserKeys")
У каждого пользователя есть своя пара открытый/закрытый ключ. Он загружает открытый ключ на сервер и хранит свой закрытый ключ в секрете. Существуют также группы пользователей. Каждый пользователь может участвовать в различных группах, но каждое сообщение может быть прикреплено только к одной определенной группе. Таким образом, только члены группы смогут прочитать содержимое поста.
При добавлении нового сообщения система генерирует временный ключ, шифрует содержимое сообщения этим ключом и для каждого члена группы шифрует временный ключ открытым ключом пользователя. Когда пользователь получает пост с сервера, он получает зашифрованное содержимое и временный ключ, зашифрованный его открытым ключом. Он может использовать закрытый ключ для расшифровки временного ключа и использовать его для расшифровки содержимого поста. Я знаю, звучит немного безумно, поэтому давайте посмотрим на код.
from pydantic import BaseModel
class PostBase(BaseModel):
title: str
content: str
group_id: int
class PostInDBBase(PostBase):
id: Optional[int] = None
class Config:
orm_mode = True
async def create_post(db: AsyncSession, post: PostBase, author: User) -> Post:
db_post = Post(
title=post.title,
content=post.content,
group_id=post.group_id,
author=author,
)
db.add(db_post)
await db.commit()
await db.refresh(db_post)
return db_post
@router.post("/posts/", response_model=PostInDBBase, status_code=201)
async def create_post(
post: PostBase,
db: DBSession,
current_user: CurrentUser,
):
plain_content = post.content
post.content = ""
post = await create_post(
db=db,
post=post,
author=current_user,
)
encrypt_post_content.delay(post_id=post.id, content=plain_content)
return post
Это представление, которое получает пост и сохраняет его в БД. Наиболее интересной частью здесь является метод encrypt_post_content.delay(). Это фактически задача celery, которая будет выполняться в отдельном процессе. Вот оно:
import os
from celery import Celery
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from sqlalchemy import create_engine, select, update
from sqlalchemy.orm import sessionmaker
from app.core.crypto_tools import (
asymmetric_encryption,
generate_symmetric_key,
symmetric_encryption,
)
from app.db.base import Post, PostKeys, User, UserGroup, UserKeys
from app.core.config import settings
celery = Celery("secureblogs")
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL")
sync_engine = create_engine(settings.SYNC_SQLALCHEMY_DATABASE_URL, echo=True)
SyncSessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=sync_engine,
)
@celery.task(name="encrypt_post_content")
def encrypt_post_content(post_id: int, content: str):
# generate temp key and encrypt content
key = generate_symmetric_key()
encrypted_content = symmetric_encryption(content, key)
with SyncSessionLocal() as db:
# update post instance
post_statement = (
update(Post)
.returning(Post.group_id)
.where(Post.id == post_id)
.values(content=encrypted_content)
)
post = db.execute(post_statement).fetchone()
# fetch user's public keys from DB
users_subquery = (
select(User.id)
.where(User.groups.any(UserGroup.id.in_([post.group_id])))
.subquery()
)
statement = select(UserKeys).where(
(UserKeys.user_id.in_(users_subquery))
& (UserKeys.is_revoked == False)
)
public_keys = db.execute(statement).scalars().all()
db_post_keys = []
for public_key in public_keys:
# Save generated keys in DB
public_pem_data = public_key.public_key
public_key_object = load_pem_public_key(public_pem_data.encode())
encrypted_key = asymmetric_encryption(key, public_key_object)
db_post_keys.append(
PostKeys(
post_id=post_id,
public_key_id=public_key.id,
encrypted_key=encrypted_key,
)
)
db.bulk_save_objects(db_post_keys)
db.commit()
Задача Celery – это просто функция sync python, поэтому для выполнения запросов к БД мы используем сессию sync базы данных внутри нее.
Websockets
Фух, похоже, мы сделали это! Наш API позволяет создавать зашифрованные посты. Но подождите, это работает только для уже существующих пользователей. Что если новый пользователь присоединится к группе и захочет прочитать какой-нибудь пост? Он не сможет этого сделать, потому что не сможет расшифровать временный ключ. Ему нужно, чтобы тот, кто создал пост, прислал ему временный ключ. И здесь очень пригодятся вебсокеты. Когда пользователь запрашивает доступ к посту, мы отправляем автору поста уведомление в реальном времени. Автор поста получает уведомление и решает, одобрить запрос или отклонить его. Итак, давайте попробуем это реализовать. Прежде всего, добавьте новую модель. Она содержит информацию о пользователе, который запрашивает доступ к посту, сам пост и открытый ключ пользователя:
class ReadPostRequest(Base):
__tablename__ = "read_post_request"
id = Column(Integer, primary_key=True, index=True)
post_id = Column(
Integer,
ForeignKey("posts.id", ondelete='CASCADE'),
nullable=False,
)
user_id = Column(
Integer, ForeignKey("users.id", ondelete='CASCADE'), nullable=False
)
public_key_id = Column(
Integer, ForeignKey("user_keys.id", ondelete='CASCADE'), nullable=False
)
post = relationship("Post")
requester = relationship("User")
public_key = relationship("UserKeys")
Теперь нам нужна конечная точка, которая позволяет создать новый запрос:
async def get_post(db: AsyncSession, post_id: int) -> Post:
statement = select(Post).where(Post.id == post_id)
result = await db.execute(statement)
return result.scalars().first()
async def get_user_key(
db: AsyncSession,
user: User,
) -> UserKeys:
statement = select(UserKeys).where(
(UserKeys.user == user) & (UserKeys.is_revoked == False)
)
result = await db.execute(statement)
return result.scalars().first()
async def add_read_post_request(
db: AsyncSession, user: User, post_id: int
) -> ReadPostRequest:
exists_statement = select(ReadPostRequest.id).where(
(ReadPostRequest.user_id == user.id)
& (ReadPostRequest.post_id == post_id)
)
result = await db.execute(exists_statement)
if result.scalars().first():
return None
public_key_statement = select(UserKeys).where(
(UserKeys.is_revoked == False) & (UserKeys.user_id == user.id)
)
result = await db.execute(public_key_statement)
if not (public_key := result.scalars().first()):
return None
db_read_post_request = ReadPostRequest(
user_id=user.id,
post_id=post_id,
public_key=public_key,
)
db.add(db_read_post_request)
await db.commit()
await db.refresh(db_read_post_request)
return db_read_post_request
@router.post("/posts/{post_id}/request_read/", status_code=204)
async def add_read_post_request(
post_id: int,
db: DBSession,
current_user: CurrentUser,
):
post = await crud_post.get_post(db, post_id)
if not post:
raise HTTPException(status_code=404)
user_key = await crud_user.get_user_key(db, current_user)
request = await crud_post.add_read_post_request(db, current_user, post_id)
if request:
await ws_manager.send_personal_message(
{
'request_id': request.id,
'post_id': post_id,
'requested_user_id': current_user.id,
'user_public_key': user_key.public_key,
},
post.user_id,
)
Здесь нет ничего нового, кроме последней строки. Мы просто проверяем, действительно ли почта существует в БД. Затем мы получаем открытый ключ пользователя, создаем новый запрос и, наконец, отправляем уведомление по websocket. Давайте посмотрим поближе, как мы это делаем.
from fastapi import WebSocket
class ConnectionManager:
def __init__(self):
self.active_connections: dict[int, WebSocket] = {}
async def connect(self, user_id: int, websocket: WebSocket):
await websocket.accept()
self.active_connections[user_id] = websocket
def disconnect(self, user_id: int):
self.active_connections.pop(user_id)
async def send_personal_message(self, message: dict, user_id: int):
if websocket := self.active_connections.get(user_id):
await websocket.send_json(message)
ws_manager = ConnectionManager()
Это наш менеджер websocket. Здесь у нас есть словарь, в котором хранятся идентификаторы пользователей и соединения websocket, связанные с каждым из них. Когда кто-то хочет отправить личное уведомление, он использует send_personal_message .
Наконец, давайте проверим, как создать новое соединение websocket.
from typing import Annotated
from fastapi import (
APIRouter,
Depends,
Query,
status,
WebSocket,
WebSocketDisconnect,
WebSocketException,
)
from app.api.deps import DBSession
from app.api.websockets.managers import ws_manager
from app.crud.crud_user import get_user_by_token
router = APIRouter()
async def get_token(
websocket: WebSocket,
token: Annotated[str | None, Query()] = None,
):
if token is None:
raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
return token
@router.websocket("/ws/post_request")
async def websocket_endpoint(
websocket: WebSocket,
db: DBSession,
token: Annotated[str, Depends(get_token)],
):
user = await get_user_by_token(db, token)
if not user:
raise WebSocketException(code=status.HTTP_401_UNAUTHORIZED)
try:
await ws_manager.connect(user.id, websocket)
await ws_manager.send_personal_message(
{"message": "connection accepted"},
user.id,
)
while True:
await websocket.receive_text()
except WebSocketDisconnect:
ws_manager.disconnect(user.id)
Мы создали новую конечную точку websocket /ws/post_request . Эта конечная точка проверяет токен пользователя, и если токен действителен, она создает новое соединение и отправляет пользователю подтверждающее сообщение о том, что соединение принято.
Заключение
Ну, вот и все. Надеюсь, этот пост был полезен.